1use crate::prelude_dev::*;
4
5pub fn broadcast_arrays<R, T, B>(tensors: Vec<TensorAny<R, T, B, IxD>>) -> Vec<TensorAny<R, T, B, IxD>>
13where
14 R: DataAPI<Data = B::Raw>,
15 B: DeviceAPI<T>,
16{
17 broadcast_arrays_f(tensors).unwrap()
18}
19
20pub fn broadcast_arrays_f<R, T, B>(tensors: Vec<TensorAny<R, T, B, IxD>>) -> Result<Vec<TensorAny<R, T, B, IxD>>>
21where
22 R: DataAPI<Data = B::Raw>,
23 B: DeviceAPI<T>,
24{
25 if tensors.len() <= 1 {
27 return Ok(tensors);
28 }
29 let device_b = tensors[0].device().clone();
30 let default_order = device_b.default_order();
31 let mut shape_b = tensors[0].shape().clone();
32 for tensor in tensors.iter().skip(1) {
33 rstsr_assert!(device_b.same_device(tensor.device()), DeviceMismatch)?;
34 let shape = tensor.shape();
35 let (shape, _, _) = broadcast_shape(shape, &shape_b, default_order)?;
36 shape_b = shape;
37 }
38 let mut tensors_new = Vec::with_capacity(tensors.len());
39 for tensor in tensors {
40 let tensor = into_broadcast_f(tensor, shape_b.clone())?;
41 tensors_new.push(tensor);
42 }
43 return Ok(tensors_new);
44}
45
46pub fn into_broadcast_f<R, T, B, D, D2>(tensor: TensorAny<R, T, B, D>, shape: D2) -> Result<TensorAny<R, T, B, D2>>
51where
52 R: DataAPI<Data = B::Raw>,
53 B: DeviceAPI<T>,
54 D: DimAPI + DimMaxAPI<D2, Max = D2>,
55 D2: DimAPI,
56{
57 let shape1 = tensor.shape();
58 let shape2 = &shape;
59 let default_order = tensor.device().default_order();
60 let (shape, tp1, _) = broadcast_shape(shape1, shape2, default_order)?;
61 let (storage, layout) = tensor.into_raw_parts();
62 let layout = update_layout_by_shape(&layout, &shape, &tp1, default_order)?;
63 unsafe { Ok(TensorBase::new_unchecked(storage, layout)) }
64}
65
66pub fn to_broadcast<R, T, B, D, D2>(tensor: &TensorAny<R, T, B, D>, shape: D2) -> TensorView<'_, T, B, D2>
72where
73 D: DimAPI + DimMaxAPI<D2, Max = D2>,
74 D2: DimAPI,
75 R: DataAPI<Data = B::Raw>,
76 B: DeviceAPI<T>,
77{
78 into_broadcast_f(tensor.view(), shape).unwrap()
79}
80
81pub fn to_broadcast_f<R, T, B, D, D2>(tensor: &TensorAny<R, T, B, D>, shape: D2) -> Result<TensorView<'_, T, B, D2>>
82where
83 D: DimAPI + DimMaxAPI<D2, Max = D2>,
84 D2: DimAPI,
85 R: DataAPI<Data = B::Raw>,
86 B: DeviceAPI<T>,
87{
88 into_broadcast_f(tensor.view(), shape)
89}
90
91pub fn into_broadcast<R, T, B, D, D2>(tensor: TensorAny<R, T, B, D>, shape: D2) -> TensorAny<R, T, B, D2>
92where
93 R: DataAPI<Data = B::Raw>,
94 B: DeviceAPI<T>,
95 D: DimAPI + DimMaxAPI<D2, Max = D2>,
96 D2: DimAPI,
97{
98 into_broadcast_f(tensor, shape).unwrap()
99}
100
101impl<R, T, B, D> TensorAny<R, T, B, D>
102where
103 R: DataAPI<Data = B::Raw>,
104 B: DeviceAPI<T>,
105 D: DimAPI,
106{
107 pub fn to_broadcast<D2>(&self, shape: D2) -> TensorView<'_, T, B, D2>
113 where
114 D2: DimAPI,
115 D: DimMaxAPI<D2, Max = D2>,
116 {
117 to_broadcast(self, shape)
118 }
119
120 pub fn to_broadcast_f<D2>(&self, shape: D2) -> Result<TensorView<'_, T, B, D2>>
121 where
122 D2: DimAPI,
123 D: DimMaxAPI<D2, Max = D2>,
124 {
125 to_broadcast_f(self, shape)
126 }
127
128 pub fn into_broadcast<D2>(self, shape: D2) -> TensorAny<R, T, B, D2>
134 where
135 D2: DimAPI,
136 D: DimMaxAPI<D2, Max = D2>,
137 {
138 into_broadcast(self, shape)
139 }
140
141 pub fn into_broadcast_f<D2>(self, shape: D2) -> Result<TensorAny<R, T, B, D2>>
142 where
143 D2: DimAPI,
144 D: DimMaxAPI<D2, Max = D2>,
145 {
146 into_broadcast_f(self, shape)
147 }
148}
149
150pub fn into_expand_dims_f<I, S, D>(tensor: TensorBase<S, D>, axes: I) -> Result<TensorBase<S, IxD>>
155where
156 D: DimAPI,
157 I: TryInto<AxesIndex<isize>, Error = Error>,
158{
159 let ndim: isize = TryInto::<isize>::try_into(tensor.ndim())?;
161 let (storage, layout) = tensor.into_raw_parts();
162 let mut layout = layout.into_dim::<IxD>()?;
163 let mut axes: Vec<isize> =
164 axes.try_into()?.as_ref().iter().map(|&v| if v >= 0 { v - ndim - 1 } else { v }).collect::<Vec<isize>>();
165 axes.sort();
166 for &axis in axes.iter() {
167 layout = layout.dim_insert(axis)?;
168 }
169 unsafe { Ok(TensorBase::new_unchecked(storage, layout)) }
170}
171
172pub fn expand_dims<I, R, T, B, D>(tensor: &TensorAny<R, T, B, D>, axes: I) -> TensorView<'_, T, B, IxD>
183where
184 D: DimAPI,
185 I: TryInto<AxesIndex<isize>, Error = Error>,
186 R: DataAPI<Data = B::Raw>,
187 B: DeviceAPI<T>,
188{
189 into_expand_dims_f(tensor.view(), axes).unwrap()
190}
191
192pub fn expand_dims_f<I, R, T, B, D>(tensor: &TensorAny<R, T, B, D>, axes: I) -> Result<TensorView<'_, T, B, IxD>>
193where
194 D: DimAPI,
195 I: TryInto<AxesIndex<isize>, Error = Error>,
196 R: DataAPI<Data = B::Raw>,
197 B: DeviceAPI<T>,
198{
199 into_expand_dims_f(tensor.view(), axes)
200}
201
202pub fn into_expand_dims<I, S, D>(tensor: TensorBase<S, D>, axes: I) -> TensorBase<S, IxD>
203where
204 D: DimAPI,
205 I: TryInto<AxesIndex<isize>, Error = Error>,
206{
207 into_expand_dims_f(tensor, axes).unwrap()
208}
209
210impl<R, T, B, D> TensorAny<R, T, B, D>
211where
212 R: DataAPI<Data = B::Raw>,
213 B: DeviceAPI<T>,
214 D: DimAPI,
215{
216 pub fn expand_dims<I>(&self, axes: I) -> TensorView<'_, T, B, IxD>
223 where
224 I: TryInto<AxesIndex<isize>, Error = Error>,
225 {
226 into_expand_dims(self.view(), axes)
227 }
228
229 pub fn expand_dims_f<I>(&self, axes: I) -> Result<TensorView<'_, T, B, IxD>>
230 where
231 I: TryInto<AxesIndex<isize>, Error = Error>,
232 {
233 into_expand_dims_f(self.view(), axes)
234 }
235
236 pub fn into_expand_dims<I>(self, axes: I) -> TensorAny<R, T, B, IxD>
243 where
244 I: TryInto<AxesIndex<isize>, Error = Error>,
245 {
246 into_expand_dims(self, axes)
247 }
248
249 pub fn into_expand_dims_f<I>(self, axes: I) -> Result<TensorAny<R, T, B, IxD>>
250 where
251 I: TryInto<AxesIndex<isize>, Error = Error>,
252 {
253 into_expand_dims_f(self, axes)
254 }
255}
256
257pub fn into_flip_f<I, S, D>(tensor: TensorBase<S, D>, axes: I) -> Result<TensorBase<S, D>>
262where
263 D: DimAPI,
264 I: TryInto<AxesIndex<isize>, Error = Error>,
265{
266 let (storage, mut layout) = tensor.into_raw_parts();
267 let axes = axes.try_into()?;
268 match axes {
269 AxesIndex::Val(axis) => {
270 layout = layout.dim_narrow(axis, slice!(None, None, -1))?;
271 },
272 AxesIndex::Vec(axes) => {
273 for &axis in axes.iter() {
274 layout = layout.dim_narrow(axis, slice!(None, None, -1))?;
275 }
276 },
277 }
278 unsafe { Ok(TensorBase::new_unchecked(storage, layout)) }
279}
280
281pub fn flip<I, R, T, B, D>(tensor: &TensorAny<R, T, B, D>, axes: I) -> TensorView<'_, T, B, D>
291where
292 D: DimAPI,
293 I: TryInto<AxesIndex<isize>, Error = Error>,
294 R: DataAPI<Data = B::Raw>,
295 B: DeviceAPI<T>,
296{
297 into_flip_f(tensor.view(), axes).unwrap()
298}
299
300pub fn flip_f<I, R, T, B, D>(tensor: &TensorAny<R, T, B, D>, axes: I) -> Result<TensorView<'_, T, B, D>>
301where
302 D: DimAPI,
303 I: TryInto<AxesIndex<isize>, Error = Error>,
304 R: DataAPI<Data = B::Raw>,
305 B: DeviceAPI<T>,
306{
307 into_flip_f(tensor.view(), axes)
308}
309
310pub fn into_flip<I, S, D>(tensor: TensorBase<S, D>, axes: I) -> TensorBase<S, D>
311where
312 D: DimAPI,
313 I: TryInto<AxesIndex<isize>, Error = Error>,
314{
315 into_flip_f(tensor, axes).unwrap()
316}
317
318impl<R, T, B, D> TensorAny<R, T, B, D>
319where
320 R: DataAPI<Data = B::Raw>,
321 B: DeviceAPI<T>,
322 D: DimAPI,
323{
324 pub fn flip<I>(&self, axis: I) -> TensorView<'_, T, B, D>
330 where
331 I: TryInto<AxesIndex<isize>, Error = Error>,
332 {
333 flip(self, axis)
334 }
335
336 pub fn flip_f<I>(&self, axis: I) -> Result<TensorView<'_, T, B, D>>
337 where
338 I: TryInto<AxesIndex<isize>, Error = Error>,
339 {
340 flip_f(self, axis)
341 }
342
343 pub fn into_flip<I>(self, axis: I) -> TensorAny<R, T, B, D>
349 where
350 I: TryInto<AxesIndex<isize>, Error = Error>,
351 {
352 into_flip(self, axis)
353 }
354
355 pub fn into_flip_f<I>(self, axis: I) -> Result<TensorAny<R, T, B, D>>
356 where
357 I: TryInto<AxesIndex<isize>, Error = Error>,
358 {
359 into_flip_f(self, axis)
360 }
361}
362
363pub fn into_transpose_f<I, S, D>(tensor: TensorBase<S, D>, axes: I) -> Result<TensorBase<S, D>>
368where
369 D: DimAPI,
370 I: TryInto<AxesIndex<isize>, Error = Error>,
371{
372 let axes = axes.try_into()?;
373 if axes.as_ref().is_empty() {
374 return Ok(into_reverse_axes(tensor));
375 }
376 let (storage, layout) = tensor.into_raw_parts();
377 let layout = layout.transpose(axes.as_ref())?;
378 unsafe { Ok(TensorBase::new_unchecked(storage, layout)) }
379}
380
381pub fn transpose<I, R, T, B, D>(tensor: &TensorAny<R, T, B, D>, axes: I) -> TensorView<'_, T, B, D>
387where
388 D: DimAPI,
389 I: TryInto<AxesIndex<isize>, Error = Error>,
390 R: DataAPI<Data = B::Raw>,
391 B: DeviceAPI<T>,
392{
393 into_transpose_f(tensor.view(), axes).unwrap()
394}
395
396pub fn transpose_f<I, R, T, B, D>(tensor: &TensorAny<R, T, B, D>, axes: I) -> Result<TensorView<'_, T, B, D>>
397where
398 D: DimAPI,
399 I: TryInto<AxesIndex<isize>, Error = Error>,
400 R: DataAPI<Data = B::Raw>,
401 B: DeviceAPI<T>,
402{
403 into_transpose_f(tensor.view(), axes)
404}
405
406pub fn into_transpose<I, S, D>(tensor: TensorBase<S, D>, axes: I) -> TensorBase<S, D>
407where
408 D: DimAPI,
409 I: TryInto<AxesIndex<isize>, Error = Error>,
410{
411 into_transpose_f(tensor, axes).unwrap()
412}
413
414pub use into_transpose as into_permute_dims;
415pub use into_transpose_f as into_permute_dims_f;
416pub use transpose as permute_dims;
417pub use transpose_f as permute_dims_f;
418
419impl<R, T, B, D> TensorAny<R, T, B, D>
420where
421 R: DataAPI<Data = B::Raw>,
422 B: DeviceAPI<T>,
423 D: DimAPI,
424{
425 pub fn transpose<I>(&self, axes: I) -> TensorView<'_, T, B, D>
431 where
432 I: TryInto<AxesIndex<isize>, Error = Error>,
433 {
434 transpose(self, axes)
435 }
436
437 pub fn transpose_f<I>(&self, axes: I) -> Result<TensorView<'_, T, B, D>>
438 where
439 I: TryInto<AxesIndex<isize>, Error = Error>,
440 {
441 transpose_f(self, axes)
442 }
443
444 pub fn into_transpose<I>(self, axes: I) -> TensorAny<R, T, B, D>
450 where
451 I: TryInto<AxesIndex<isize>, Error = Error>,
452 {
453 into_transpose(self, axes)
454 }
455
456 pub fn into_transpose_f<I>(self, axes: I) -> Result<TensorAny<R, T, B, D>>
457 where
458 I: TryInto<AxesIndex<isize>, Error = Error>,
459 {
460 into_transpose_f(self, axes)
461 }
462
463 pub fn permute_dims<I>(&self, axes: I) -> TensorView<'_, T, B, D>
469 where
470 I: TryInto<AxesIndex<isize>, Error = Error>,
471 {
472 transpose(self, axes)
473 }
474
475 pub fn permute_dims_f<I>(&self, axes: I) -> Result<TensorView<'_, T, B, D>>
476 where
477 I: TryInto<AxesIndex<isize>, Error = Error>,
478 {
479 transpose_f(self, axes)
480 }
481
482 pub fn into_permute_dims<I>(self, axes: I) -> TensorAny<R, T, B, D>
488 where
489 I: TryInto<AxesIndex<isize>, Error = Error>,
490 {
491 into_transpose(self, axes)
492 }
493
494 pub fn into_permute_dims_f<I>(self, axes: I) -> Result<TensorAny<R, T, B, D>>
495 where
496 I: TryInto<AxesIndex<isize>, Error = Error>,
497 {
498 into_transpose_f(self, axes)
499 }
500}
501
502pub fn into_reverse_axes<S, D>(tensor: TensorBase<S, D>) -> TensorBase<S, D>
507where
508 D: DimAPI,
509{
510 let (storage, layout) = tensor.into_raw_parts();
511 let layout = layout.reverse_axes();
512 unsafe { TensorBase::new_unchecked(storage, layout) }
513}
514
515pub fn reverse_axes<R, T, B, D>(tensor: &TensorAny<R, T, B, D>) -> TensorView<'_, T, B, D>
517where
518 D: DimAPI,
519 R: DataAPI<Data = B::Raw>,
520 B: DeviceAPI<T>,
521{
522 into_reverse_axes(tensor.view())
523}
524
525impl<R, T, B, D> TensorAny<R, T, B, D>
526where
527 R: DataAPI<Data = B::Raw>,
528 B: DeviceAPI<T>,
529 D: DimAPI,
530{
531 pub fn reverse_axes(&self) -> TensorView<'_, T, B, D> {
537 into_reverse_axes(self.view())
538 }
539
540 pub fn into_reverse_axes(self) -> TensorAny<R, T, B, D> {
546 into_reverse_axes(self)
547 }
548
549 pub fn t(&self) -> TensorView<'_, T, B, D> {
555 into_reverse_axes(self.view())
556 }
557}
558
559pub fn into_swapaxes_f<I, S, D>(tensor: TensorBase<S, D>, axis1: I, axis2: I) -> Result<TensorBase<S, D>>
564where
565 D: DimAPI,
566 I: TryInto<isize>,
567{
568 let axis1 = axis1.try_into().map_err(|_| rstsr_error!(TryFromIntError))?;
569 let axis2 = axis2.try_into().map_err(|_| rstsr_error!(TryFromIntError))?;
570 let (storage, layout) = tensor.into_raw_parts();
571 let layout = layout.swapaxes(axis1, axis2)?;
572 unsafe { Ok(TensorBase::new_unchecked(storage, layout)) }
573}
574
575pub fn swapaxes<I, R, T, B, D>(tensor: &TensorAny<R, T, B, D>, axis1: I, axis2: I) -> TensorView<'_, T, B, D>
581where
582 D: DimAPI,
583 I: TryInto<isize>,
584 R: DataAPI<Data = B::Raw>,
585 B: DeviceAPI<T>,
586{
587 into_swapaxes_f(tensor.view(), axis1, axis2).unwrap()
588}
589
590pub fn swapaxes_f<I, R, T, B, D>(tensor: &TensorAny<R, T, B, D>, axis1: I, axis2: I) -> Result<TensorView<'_, T, B, D>>
591where
592 D: DimAPI,
593 I: TryInto<isize>,
594 R: DataAPI<Data = B::Raw>,
595 B: DeviceAPI<T>,
596{
597 into_swapaxes_f(tensor.view(), axis1, axis2)
598}
599
600pub fn into_swapaxes<I, S, D>(tensor: TensorBase<S, D>, axis1: I, axis2: I) -> TensorBase<S, D>
601where
602 D: DimAPI,
603 I: TryInto<isize>,
604{
605 into_swapaxes_f(tensor, axis1, axis2).unwrap()
606}
607
608impl<R, T, B, D> TensorAny<R, T, B, D>
609where
610 R: DataAPI<Data = B::Raw>,
611 B: DeviceAPI<T>,
612 D: DimAPI,
613{
614 pub fn swapaxes<I>(&self, axis1: I, axis2: I) -> TensorView<'_, T, B, D>
620 where
621 I: TryInto<isize>,
622 {
623 swapaxes(self, axis1, axis2)
624 }
625
626 pub fn swapaxes_f<I>(&self, axis1: I, axis2: I) -> Result<TensorView<'_, T, B, D>>
627 where
628 I: TryInto<isize>,
629 {
630 swapaxes_f(self, axis1, axis2)
631 }
632
633 pub fn into_swapaxes<I>(self, axis1: I, axis2: I) -> TensorAny<R, T, B, D>
639 where
640 I: TryInto<isize>,
641 {
642 into_swapaxes(self, axis1, axis2)
643 }
644
645 pub fn into_swapaxes_f<I>(self, axis1: I, axis2: I) -> Result<TensorAny<R, T, B, D>>
646 where
647 I: TryInto<isize>,
648 {
649 into_swapaxes_f(self, axis1, axis2)
650 }
651}
652
653pub fn into_squeeze_f<I, S, D>(tensor: TensorBase<S, D>, axes: I) -> Result<TensorBase<S, IxD>>
658where
659 D: DimAPI,
660 I: TryInto<AxesIndex<isize>, Error = Error>,
661{
662 let ndim: isize = TryInto::<isize>::try_into(tensor.ndim())?;
664 let (storage, layout) = tensor.into_raw_parts();
665 let mut layout = layout.into_dim::<IxD>()?;
666 let mut axes: Vec<isize> =
667 axes.try_into()?.as_ref().iter().map(|&v| if v >= 0 { v } else { v + ndim }).collect::<_>();
668 axes.sort_by(|a, b| b.cmp(a));
669 if axes.first().is_some_and(|&v| v < 0) {
670 return Err(rstsr_error!(InvalidValue, "Some negative index is too small."));
671 }
672 for i in 0..axes.len() - 1 {
674 rstsr_assert!(axes[i] != axes[i + 1], InvalidValue, "Same axes is not allowed here.")?;
675 }
676 for &axis in axes.iter() {
678 layout = layout.dim_eliminate(axis)?;
679 }
680 unsafe { Ok(TensorBase::new_unchecked(storage, layout)) }
681}
682
683pub fn squeeze<I, R, T, B, D>(tensor: &TensorAny<R, T, B, D>, axes: I) -> TensorView<'_, T, B, IxD>
689where
690 D: DimAPI,
691 I: TryInto<AxesIndex<isize>, Error = Error>,
692 R: DataAPI<Data = B::Raw>,
693 B: DeviceAPI<T>,
694{
695 into_squeeze_f(tensor.view(), axes).unwrap()
696}
697
698pub fn squeeze_f<I, R, T, B, D>(tensor: &TensorAny<R, T, B, D>, axes: I) -> Result<TensorView<'_, T, B, IxD>>
699where
700 D: DimAPI,
701 I: TryInto<AxesIndex<isize>, Error = Error>,
702 R: DataAPI<Data = B::Raw>,
703 B: DeviceAPI<T>,
704{
705 into_squeeze_f(tensor.view(), axes)
706}
707
708pub fn into_squeeze<I, S, D>(tensor: TensorBase<S, D>, axes: I) -> TensorBase<S, IxD>
709where
710 D: DimAPI,
711 I: TryInto<AxesIndex<isize>, Error = Error>,
712{
713 into_squeeze_f(tensor, axes).unwrap()
714}
715
716impl<R, T, B, D> TensorAny<R, T, B, D>
717where
718 R: DataAPI<Data = B::Raw>,
719 B: DeviceAPI<T>,
720 D: DimAPI,
721{
722 pub fn squeeze<I>(&self, axis: I) -> TensorView<'_, T, B, IxD>
728 where
729 I: TryInto<AxesIndex<isize>, Error = Error>,
730 {
731 squeeze(self, axis)
732 }
733
734 pub fn squeeze_f<I>(&self, axis: I) -> Result<TensorView<'_, T, B, IxD>>
735 where
736 I: TryInto<AxesIndex<isize>, Error = Error>,
737 {
738 squeeze_f(self, axis)
739 }
740
741 pub fn into_squeeze<I>(self, axis: I) -> TensorAny<R, T, B, IxD>
747 where
748 I: TryInto<AxesIndex<isize>, Error = Error>,
749 {
750 into_squeeze(self, axis)
751 }
752
753 pub fn into_squeeze_f<I>(self, axis: I) -> Result<TensorAny<R, T, B, IxD>>
754 where
755 I: TryInto<AxesIndex<isize>, Error = Error>,
756 {
757 into_squeeze_f(self, axis)
758 }
759}
760
761pub fn into_dim_f<S, D, D2>(tensor: TensorBase<S, D>) -> Result<TensorBase<S, D2>>
766where
767 D: DimAPI + DimIntoAPI<D2>,
768 D2: DimAPI,
769{
770 let (storage, layout) = tensor.into_raw_parts();
771 let layout = layout.into_dim::<D2>()?;
772 unsafe { Ok(TensorBase::new_unchecked(storage, layout)) }
773}
774
775pub fn to_dim<R, T, B, D, D2>(tensor: &TensorAny<R, T, B, D>) -> TensorView<'_, T, B, D2>
780where
781 D: DimAPI,
782 D2: DimAPI,
783 D: DimIntoAPI<D2>,
784 R: DataAPI<Data = B::Raw>,
785 B: DeviceAPI<T>,
786{
787 into_dim_f(tensor.view()).unwrap()
788}
789
790pub fn to_dim_f<R, T, B, D, D2>(tensor: &TensorAny<R, T, B, D>) -> Result<TensorView<'_, T, B, D2>>
791where
792 D: DimAPI,
793 D2: DimAPI,
794 D: DimIntoAPI<D2>,
795 R: DataAPI<Data = B::Raw>,
796 B: DeviceAPI<T>,
797{
798 into_dim_f(tensor.view())
799}
800
801pub fn into_dim<S, D, D2>(tensor: TensorBase<S, D>) -> TensorBase<S, D2>
802where
803 D: DimAPI,
804 D2: DimAPI,
805 D: DimIntoAPI<D2>,
806{
807 into_dim_f(tensor).unwrap()
808}
809
810pub fn to_dyn<R, T, B, D>(tensor: &TensorAny<R, T, B, D>) -> TensorView<'_, T, B, IxD>
811where
812 D: DimAPI,
813 R: DataAPI<Data = B::Raw>,
814 B: DeviceAPI<T>,
815{
816 into_dim_f(tensor.view()).unwrap()
817}
818
819pub fn into_dyn<S, D>(tensor: TensorBase<S, D>) -> TensorBase<S, IxD>
820where
821 D: DimAPI,
822{
823 into_dim_f(tensor).unwrap()
824}
825
826impl<R, T, B, D> TensorAny<R, T, B, D>
827where
828 D: DimAPI,
829 R: DataAPI<Data = B::Raw>,
830 B: DeviceAPI<T>,
831{
832 pub fn to_dim<D2>(&self) -> TensorView<'_, T, B, D2>
841 where
842 D2: DimAPI,
843 D: DimIntoAPI<D2>,
844 {
845 to_dim(self)
846 }
847
848 pub fn to_dim_f<D2>(&self) -> Result<TensorView<'_, T, B, D2>>
849 where
850 D2: DimAPI,
851 D: DimIntoAPI<D2>,
852 {
853 to_dim_f(self)
854 }
855
856 pub fn into_dim<D2>(self) -> TensorAny<R, T, B, D2>
862 where
863 D2: DimAPI,
864 D: DimIntoAPI<D2>,
865 {
866 into_dim(self)
867 }
868
869 pub fn into_dim_f<D2>(self) -> Result<TensorAny<R, T, B, D2>>
870 where
871 D2: DimAPI,
872 D: DimIntoAPI<D2>,
873 {
874 into_dim_f(self)
875 }
876
877 pub fn to_dyn(&self) -> TensorView<'_, T, B, IxD> {
879 to_dyn(self)
880 }
881
882 pub fn into_dyn(self) -> TensorAny<R, T, B, IxD> {
884 into_dyn(self)
885 }
886}
887
888pub fn into_shape_assume_contig_f<R, T, B, D, D2>(
893 tensor: TensorAny<R, T, B, D>,
894 shape: D2,
895) -> Result<TensorAny<R, T, B, D2>>
896where
897 R: DataAPI<Data = B::Raw>,
898 B: DeviceAPI<T>,
899 D: DimAPI,
900 D2: DimAPI,
901{
902 let default_order = tensor.device().default_order();
903 let (storage, layout) = tensor.into_raw_parts();
904
905 rstsr_assert_eq!(layout.size(), shape.shape_size(), InvalidLayout, "Number of elements not same.")?;
906
907 let new_layout = {
908 if default_order == FlagOrder::C && layout.c_contig() {
909 shape.new_c_contig(Some(layout.offset()))
910 } else if default_order == FlagOrder::F && layout.f_contig() {
911 shape.new_f_contig(Some(layout.offset()))
912 } else {
913 rstsr_raise!(InvalidLayout, "This array is not contiguous by {:?}", default_order)?
914 }
915 };
916 unsafe { Ok(TensorBase::new_unchecked(storage, new_layout)) }
917}
918
919pub fn to_shape_assume_contig<R, T, B, D, D2>(tensor: &TensorAny<R, T, B, D>, shape: D2) -> TensorView<'_, T, B, D2>
928where
929 D: DimAPI,
930 D2: DimAPI,
931 R: DataAPI<Data = B::Raw>,
932 B: DeviceAPI<T>,
933{
934 into_shape_assume_contig_f(tensor.view(), shape).unwrap()
935}
936
937pub fn to_shape_assume_contig_f<R, T, B, D, D2>(
938 tensor: &TensorAny<R, T, B, D>,
939 shape: D2,
940) -> Result<TensorView<'_, T, B, D2>>
941where
942 D: DimAPI,
943 D2: DimAPI,
944 R: DataAPI<Data = B::Raw>,
945 B: DeviceAPI<T>,
946{
947 into_shape_assume_contig_f(tensor.view(), shape)
948}
949
950pub fn into_shape_assume_contig<R, T, B, D, D2>(tensor: TensorAny<R, T, B, D>, shape: D2) -> TensorAny<R, T, B, D2>
951where
952 R: DataAPI<Data = B::Raw>,
953 B: DeviceAPI<T>,
954 D: DimAPI,
955 D2: DimAPI,
956{
957 into_shape_assume_contig_f(tensor, shape).unwrap()
958}
959
960pub use to_shape_assume_contig as reshape_assume_contig;
961pub use to_shape_assume_contig_f as reshape_assume_contig_f;
962
963impl<R, T, B, D> TensorAny<R, T, B, D>
964where
965 R: DataAPI<Data = B::Raw>,
966 B: DeviceAPI<T>,
967 D: DimAPI,
968{
969 pub fn reshape_assume_contig<D2>(&self, shape: D2) -> TensorView<'_, T, B, D2>
975 where
976 D2: DimAPI,
977 {
978 into_shape_assume_contig(self.view(), shape)
979 }
980
981 pub fn reshape_assume_contig_f<D2>(&self, shape: D2) -> Result<TensorView<'_, T, B, D2>>
982 where
983 D2: DimAPI,
984 {
985 into_shape_assume_contig_f(self.view(), shape)
986 }
987
988 pub fn to_shape_assume_contig<D2>(&self, shape: D2) -> TensorView<'_, T, B, D2>
989 where
990 D2: DimAPI,
991 {
992 into_shape_assume_contig(self.view(), shape)
993 }
994
995 pub fn to_shape_assume_contig_f<D2>(&self, shape: D2) -> Result<TensorView<'_, T, B, D2>>
996 where
997 D2: DimAPI,
998 {
999 into_shape_assume_contig_f(self.view(), shape)
1000 }
1001
1002 pub fn into_shape_assume_contig<D2>(self, shape: D2) -> TensorAny<R, T, B, D2>
1003 where
1004 D2: DimAPI,
1005 {
1006 into_shape_assume_contig(self, shape)
1007 }
1008
1009 pub fn into_shape_assume_contig_f<D2>(self, shape: D2) -> Result<TensorAny<R, T, B, D2>>
1010 where
1011 D2: DimAPI,
1012 {
1013 into_shape_assume_contig_f(self, shape)
1014 }
1015}
1016
1017pub fn change_shape_f<'a, I, R, T, B, D>(tensor: TensorAny<R, T, B, D>, shape: I) -> Result<TensorCow<'a, T, B, IxD>>
1022where
1023 I: TryInto<AxesIndex<isize>, Error = Error>,
1024 R: DataAPI<Data = B::Raw> + DataIntoCowAPI<'a>,
1025 D: DimAPI,
1026 B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, IxD, D>,
1027{
1028 let shape_new = reshape_substitute_negatives(shape.try_into()?.as_ref(), tensor.size())?;
1030 let default_order = tensor.device().default_order();
1031 if let Some(layout_new) = layout_reshapeable(&tensor.layout().to_dim()?, &shape_new, default_order)? {
1032 let (storage, _) = tensor.into_raw_parts();
1034 let layout = layout_new.into_dim::<IxD>()?;
1035 return unsafe { Ok(TensorBase::new_unchecked(storage, layout).into_cow()) };
1036 } else {
1037 let (storage, layout) = tensor.into_raw_parts();
1039 let device = storage.device();
1040 let layout_new = match default_order {
1041 RowMajor => shape_new.new_c_contig(None),
1042 ColMajor => shape_new.new_f_contig(None),
1043 };
1044 let mut storage_new = unsafe { device.empty_impl(layout_new.size())? };
1045 device.assign_arbitary(storage_new.raw_mut(), &layout_new, storage.raw(), &layout)?;
1046 return unsafe { Ok(TensorBase::new_unchecked(storage_new, layout_new).into_cow()) };
1047 }
1048}
1049
1050pub fn change_shape<'a, I, R, T, B, D>(tensor: TensorAny<R, T, B, D>, shape: I) -> TensorCow<'a, T, B, IxD>
1051where
1052 I: TryInto<AxesIndex<isize>, Error = Error>,
1053 R: DataAPI<Data = B::Raw> + DataIntoCowAPI<'a>,
1054 D: DimAPI,
1055 B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, IxD, D>,
1056{
1057 change_shape_f(tensor, shape).unwrap()
1058}
1059
1060pub fn into_shape_f<'a, I, R, T, B, D>(tensor: TensorAny<R, T, B, D>, shape: I) -> Result<Tensor<T, B, IxD>>
1061where
1062 I: TryInto<AxesIndex<isize>, Error = Error>,
1063 R: DataAPI<Data = B::Raw> + DataIntoCowAPI<'a>,
1064 D: DimAPI,
1065 T: Clone,
1066 B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, IxD, D> + OpAssignAPI<T, IxD>,
1067 B::Raw: Clone + 'a,
1068{
1069 change_shape_f(tensor, shape).map(|v| v.into_owned())
1070}
1071
1072pub fn into_shape<'a, I, R, T, B, D>(tensor: TensorAny<R, T, B, D>, shape: I) -> Tensor<T, B, IxD>
1073where
1074 I: TryInto<AxesIndex<isize>, Error = Error>,
1075 R: DataAPI<Data = B::Raw> + DataIntoCowAPI<'a>,
1076 D: DimAPI,
1077 T: Clone,
1078 B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, IxD, D> + OpAssignAPI<T, IxD>,
1079 B::Raw: Clone + 'a,
1080{
1081 into_shape_f(tensor, shape).unwrap()
1082}
1083
1084pub fn to_shape_f<'a, I, R, T, B, D>(tensor: &'a TensorAny<R, T, B, D>, shape: I) -> Result<TensorCow<'a, T, B, IxD>>
1085where
1086 I: TryInto<AxesIndex<isize>, Error = Error>,
1087 R: DataAPI<Data = B::Raw> + DataIntoCowAPI<'a>,
1088 D: DimAPI,
1089 B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, IxD, D>,
1090{
1091 change_shape_f(tensor.view(), shape)
1092}
1093
1094pub fn to_shape<'a, I, R, T, B, D>(tensor: &'a TensorAny<R, T, B, D>, shape: I) -> TensorCow<'a, T, B, IxD>
1095where
1096 I: TryInto<AxesIndex<isize>, Error = Error>,
1097 R: DataAPI<Data = B::Raw> + DataIntoCowAPI<'a>,
1098 D: DimAPI,
1099 B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, IxD, D>,
1100{
1101 to_shape_f(tensor, shape).unwrap()
1102}
1103
1104pub fn reshape_f<'a, I, R, T, B, D>(tensor: &'a TensorAny<R, T, B, D>, shape: I) -> Result<TensorCow<'a, T, B, IxD>>
1105where
1106 I: TryInto<AxesIndex<isize>, Error = Error>,
1107 R: DataAPI<Data = B::Raw> + DataIntoCowAPI<'a>,
1108 D: DimAPI,
1109 B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, IxD, D>,
1110{
1111 to_shape_f(tensor, shape)
1112}
1113
1114pub fn reshape<'a, I, R, T, B, D>(tensor: &'a TensorAny<R, T, B, D>, shape: I) -> TensorCow<'a, T, B, IxD>
1115where
1116 I: TryInto<AxesIndex<isize>, Error = Error>,
1117 R: DataAPI<Data = B::Raw> + DataIntoCowAPI<'a>,
1118 D: DimAPI,
1119 B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, IxD, D>,
1120{
1121 to_shape(tensor, shape)
1122}
1123
1124impl<'a, R, T, B, D> TensorAny<R, T, B, D>
1125where
1126 R: DataAPI<Data = B::Raw> + DataIntoCowAPI<'a>,
1127 D: DimAPI,
1128 B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, IxD, D> + OpAssignAPI<T, IxD>,
1129 T: Clone,
1130{
1131 pub fn change_shape_f<I>(self, shape: I) -> Result<TensorCow<'a, T, B, IxD>>
1132 where
1133 I: TryInto<AxesIndex<isize>, Error = Error>,
1134 {
1135 change_shape_f(self, shape)
1136 }
1137
1138 pub fn change_shape<I>(self, shape: I) -> TensorCow<'a, T, B, IxD>
1139 where
1140 I: TryInto<AxesIndex<isize>, Error = Error>,
1141 {
1142 change_shape(self, shape)
1143 }
1144
1145 pub fn into_shape_f<I>(self, shape: I) -> Result<Tensor<T, B, IxD>>
1146 where
1147 I: TryInto<AxesIndex<isize>, Error = Error>,
1148 B::Raw: Clone + 'a,
1149 {
1150 into_shape_f(self, shape)
1151 }
1152
1153 pub fn into_shape<I>(self, shape: I) -> Tensor<T, B, IxD>
1154 where
1155 I: TryInto<AxesIndex<isize>, Error = Error>,
1156 B::Raw: Clone + 'a,
1157 {
1158 into_shape(self, shape)
1159 }
1160
1161 pub fn to_shape_f<I>(&'a self, shape: I) -> Result<TensorCow<'a, T, B, IxD>>
1162 where
1163 I: TryInto<AxesIndex<isize>, Error = Error>,
1164 {
1165 self.view().change_shape_f(shape)
1166 }
1167
1168 pub fn to_shape<I>(&'a self, shape: I) -> TensorCow<'a, T, B, IxD>
1169 where
1170 I: TryInto<AxesIndex<isize>, Error = Error>,
1171 {
1172 self.view().change_shape(shape)
1173 }
1174
1175 pub fn reshape_f<I>(&'a self, shape: I) -> Result<TensorCow<'a, T, B, IxD>>
1176 where
1177 I: TryInto<AxesIndex<isize>, Error = Error>,
1178 {
1179 self.view().change_shape_f(shape)
1180 }
1181
1182 pub fn reshape<I>(&'a self, shape: I) -> TensorCow<'a, T, B, IxD>
1183 where
1184 I: TryInto<AxesIndex<isize>, Error = Error>,
1185 {
1186 self.view().change_shape(shape)
1187 }
1188}
1189
1190pub fn change_layout_f<'a, R, T, B, D, D2>(
1195 tensor: TensorAny<R, T, B, D>,
1196 layout: Layout<D2>,
1197) -> Result<TensorCow<'a, T, B, D2>>
1198where
1199 R: DataAPI<Data = B::Raw> + DataIntoCowAPI<'a>,
1200 D: DimAPI,
1201 D2: DimAPI,
1202 B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, D2, D>,
1203{
1204 let shape = layout.shape();
1205 rstsr_assert_eq!(tensor.size(), shape.shape_size(), InvalidLayout)?;
1206 let same_layout = tensor.layout().to_dim::<IxD>()? == layout.to_dim::<IxD>()?;
1207 let contig_c = tensor.c_contig() && layout.c_contig() && tensor.layout().offset() == layout.offset();
1208 let contig_f = tensor.f_contig() && layout.f_contig() && tensor.layout().offset() == layout.offset();
1209 let default_order = tensor.device().default_order();
1210 let contig = match default_order {
1211 RowMajor => contig_c,
1212 ColMajor => contig_f,
1213 };
1214 if same_layout || contig {
1215 let (storage, _) = tensor.into_raw_parts();
1217 let tensor = unsafe { TensorBase::new_unchecked(storage, layout) };
1218 return Ok(tensor.into_cow());
1219 } else {
1220 let (storage_old, layout_old) = tensor.into_raw_parts();
1223 let device = storage_old.device();
1224 let (_, idx_max) = layout.bounds_index()?;
1225 let mut storage_new = unsafe { device.empty_impl(idx_max)? };
1226 device.assign_arbitary(storage_new.raw_mut(), &layout, storage_old.raw(), &layout_old)?;
1227 let tensor = unsafe { TensorBase::new_unchecked(storage_new, layout) };
1228 return Ok(tensor.into_cow());
1229 }
1230}
1231
1232pub fn to_layout<R, T, D, B, D2>(tensor: &TensorAny<R, T, B, D>, layout: Layout<D2>) -> TensorCow<'_, T, B, D2>
1234where
1235 R: DataAPI<Data = B::Raw>,
1236 D: DimAPI,
1237 D2: DimAPI,
1238 B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, D2, D>,
1239{
1240 change_layout_f(tensor.view(), layout).unwrap()
1241}
1242
1243pub fn to_layout_f<R, T, D, B, D2>(
1244 tensor: &TensorAny<R, T, B, D>,
1245 layout: Layout<D2>,
1246) -> Result<TensorCow<'_, T, B, D2>>
1247where
1248 R: DataAPI<Data = B::Raw>,
1249 D: DimAPI,
1250 D2: DimAPI,
1251 B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, D2, D>,
1252{
1253 change_layout_f(tensor.view(), layout)
1254}
1255
1256pub fn into_layout_f<'a, R, T, B, D, D2>(tensor: TensorAny<R, T, B, D>, layout: Layout<D2>) -> Result<Tensor<T, B, D2>>
1257where
1258 R: DataAPI<Data = B::Raw> + DataIntoCowAPI<'a>,
1259 D: DimAPI,
1260 D2: DimAPI,
1261 T: Clone,
1262 B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, D2, D> + OpAssignAPI<T, D2>,
1263 B::Raw: Clone + 'a,
1264{
1265 change_layout_f(tensor, layout).map(|v| v.into_owned())
1266}
1267
1268pub fn into_layout<'a, R, T, B, D, D2>(tensor: TensorAny<R, T, B, D>, layout: Layout<D2>) -> Tensor<T, B, D2>
1269where
1270 R: DataAPI<Data = B::Raw> + DataIntoCowAPI<'a>,
1271 D: DimAPI,
1272 D2: DimAPI,
1273 T: Clone,
1274 B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, D2, D> + OpAssignAPI<T, D2>,
1275 B::Raw: Clone + 'a,
1276{
1277 into_layout_f(tensor, layout).unwrap()
1278}
1279
1280pub fn change_layout<'a, R, T, B, D, D2>(tensor: TensorAny<R, T, B, D>, layout: Layout<D2>) -> TensorCow<'a, T, B, D2>
1281where
1282 R: DataAPI<Data = B::Raw> + DataIntoCowAPI<'a>,
1283 D: DimAPI,
1284 D2: DimAPI,
1285 B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, D2, D>,
1286{
1287 change_layout_f(tensor, layout).unwrap()
1288}
1289
1290impl<'a, R, T, B, D> TensorAny<R, T, B, D>
1291where
1292 R: DataAPI<Data = B::Raw> + DataIntoCowAPI<'a>,
1293 D: DimAPI,
1294 T: Clone,
1295 B: DeviceAPI<T> + DeviceCreationAnyAPI<T>,
1296{
1297 pub fn to_layout<D2>(&self, layout: Layout<D2>) -> TensorCow<'_, T, B, D2>
1303 where
1304 D2: DimAPI,
1305 B: OpAssignArbitaryAPI<T, D2, D>,
1306 {
1307 to_layout(self, layout)
1308 }
1309
1310 pub fn to_layout_f<D2>(&self, layout: Layout<D2>) -> Result<TensorCow<'_, T, B, D2>>
1311 where
1312 D2: DimAPI,
1313 B: OpAssignArbitaryAPI<T, D2, D>,
1314 {
1315 to_layout_f(self, layout)
1316 }
1317
1318 pub fn into_layout_f<D2>(self, layout: Layout<D2>) -> Result<Tensor<T, B, D2>>
1319 where
1320 D2: DimAPI,
1321 B: OpAssignArbitaryAPI<T, D2, D> + OpAssignAPI<T, D2>,
1322 B::Raw: Clone + 'a,
1323 {
1324 into_layout_f(self, layout)
1325 }
1326
1327 pub fn into_layout<D2>(self, layout: Layout<D2>) -> Tensor<T, B, D2>
1328 where
1329 D2: DimAPI,
1330 B: OpAssignArbitaryAPI<T, D2, D> + OpAssignAPI<T, D2>,
1331 B::Raw: Clone + 'a,
1332 {
1333 into_layout(self, layout)
1334 }
1335
1336 pub fn change_layout_f<D2>(self, layout: Layout<D2>) -> Result<TensorCow<'a, T, B, D2>>
1337 where
1338 D2: DimAPI,
1339 B: OpAssignArbitaryAPI<T, D2, D>,
1340 {
1341 change_layout_f(self, layout)
1342 }
1343
1344 pub fn change_layout<D2>(self, layout: Layout<D2>) -> TensorCow<'a, T, B, D2>
1345 where
1346 D2: DimAPI,
1347 B: OpAssignArbitaryAPI<T, D2, D>,
1348 {
1349 change_layout(self, layout)
1350 }
1351}
1352
1353pub fn change_contig_f<'a, R, T, B, D>(
1358 tensor: TensorAny<R, T, B, D>,
1359 order: FlagOrder,
1360) -> Result<TensorCow<'a, T, B, D>>
1361where
1362 R: DataAPI<Data = B::Raw> + DataIntoCowAPI<'a>,
1363 D: DimAPI,
1364 B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, D, D>,
1365{
1366 let shape = tensor.shape();
1367 let layout_new = match order {
1368 RowMajor => shape.new_c_contig(None),
1369 ColMajor => shape.new_f_contig(None),
1370 };
1371 change_layout_f(tensor, layout_new)
1372}
1373
1374pub fn to_contig_f<R, T, B, D>(tensor: &TensorAny<R, T, B, D>, order: FlagOrder) -> Result<TensorCow<'_, T, B, D>>
1375where
1376 R: DataAPI<Data = B::Raw>,
1377 D: DimAPI,
1378 B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, D, D>,
1379{
1380 change_contig_f(tensor.view(), order)
1381}
1382
1383pub fn into_contig_f<'a, R, T, B, D>(tensor: TensorAny<R, T, B, D>, order: FlagOrder) -> Result<Tensor<T, B, D>>
1384where
1385 R: DataAPI<Data = B::Raw> + DataIntoCowAPI<'a>,
1386 D: DimAPI,
1387 T: Clone,
1388 B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, D, D> + OpAssignAPI<T, D>,
1389 B::Raw: Clone + 'a,
1390{
1391 change_contig_f(tensor, order).map(|v| v.into_owned())
1392}
1393
1394pub fn change_contig<'a, R, T, B, D>(tensor: TensorAny<R, T, B, D>, order: FlagOrder) -> TensorCow<'a, T, B, D>
1395where
1396 R: DataAPI<Data = B::Raw> + DataIntoCowAPI<'a>,
1397 D: DimAPI,
1398 B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, D, D>,
1399{
1400 change_contig_f(tensor, order).unwrap()
1401}
1402
1403pub fn to_contig<R, T, B, D>(tensor: &TensorAny<R, T, B, D>, order: FlagOrder) -> TensorCow<'_, T, B, D>
1404where
1405 R: DataAPI<Data = B::Raw>,
1406 D: DimAPI,
1407 B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, D, D>,
1408{
1409 to_contig_f(tensor, order).unwrap()
1410}
1411
1412pub fn into_contig<'a, R, T, B, D>(tensor: TensorAny<R, T, B, D>, order: FlagOrder) -> Tensor<T, B, D>
1413where
1414 R: DataAPI<Data = B::Raw> + DataIntoCowAPI<'a>,
1415 D: DimAPI,
1416 T: Clone,
1417 B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, D, D> + OpAssignAPI<T, D>,
1418 B::Raw: Clone + 'a,
1419{
1420 into_contig_f(tensor, order).unwrap()
1421}
1422
1423impl<'a, R, T, B, D> TensorAny<R, T, B, D>
1424where
1425 R: DataAPI<Data = B::Raw> + DataIntoCowAPI<'a>,
1426 D: DimAPI,
1427 T: Clone,
1428 B: DeviceAPI<T> + DeviceCreationAnyAPI<T>,
1429{
1430 pub fn to_contig(&self, order: FlagOrder) -> TensorCow<'_, T, B, D>
1432 where
1433 B: OpAssignArbitaryAPI<T, D, D>,
1434 {
1435 to_contig(self, order)
1436 }
1437
1438 pub fn to_contig_f(&self, order: FlagOrder) -> Result<TensorCow<'_, T, B, D>>
1439 where
1440 B: OpAssignArbitaryAPI<T, D, D>,
1441 {
1442 to_contig_f(self, order)
1443 }
1444
1445 pub fn into_contig_f(self, order: FlagOrder) -> Result<Tensor<T, B, D>>
1446 where
1447 B: OpAssignArbitaryAPI<T, D, D> + OpAssignAPI<T, D>,
1448 B::Raw: Clone + 'a,
1449 {
1450 into_contig_f(self, order)
1451 }
1452
1453 pub fn into_contig(self, order: FlagOrder) -> Tensor<T, B, D>
1454 where
1455 B: OpAssignArbitaryAPI<T, D, D> + OpAssignAPI<T, D>,
1456 B::Raw: Clone + 'a,
1457 {
1458 into_contig(self, order)
1459 }
1460
1461 pub fn change_contig_f(self, order: FlagOrder) -> Result<TensorCow<'a, T, B, D>>
1462 where
1463 B: OpAssignArbitaryAPI<T, D, D>,
1464 {
1465 change_contig_f(self, order)
1466 }
1467
1468 pub fn change_contig(self, order: FlagOrder) -> TensorCow<'a, T, B, D>
1469 where
1470 B: OpAssignArbitaryAPI<T, D, D>,
1471 {
1472 change_contig(self, order)
1473 }
1474}
1475
1476pub fn change_prefer_f<'a, R, T, B, D>(
1481 tensor: TensorAny<R, T, B, D>,
1482 order: FlagOrder,
1483) -> Result<TensorCow<'a, T, B, D>>
1484where
1485 R: DataAPI<Data = B::Raw> + DataIntoCowAPI<'a>,
1486 D: DimAPI,
1487 B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, D, D>,
1488{
1489 if (order == FlagOrder::C && tensor.c_prefer()) || (order == FlagOrder::F && tensor.f_prefer()) {
1490 Ok(tensor.into_cow())
1491 } else {
1492 change_contig_f(tensor, order)
1493 }
1494}
1495
1496pub fn to_prefer_f<R, T, B, D>(tensor: &TensorAny<R, T, B, D>, order: FlagOrder) -> Result<TensorCow<'_, T, B, D>>
1497where
1498 R: DataAPI<Data = B::Raw>,
1499 D: DimAPI,
1500 B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, D, D>,
1501{
1502 change_prefer_f(tensor.view(), order)
1503}
1504
1505pub fn into_prefer_f<'a, R, T, B, D>(tensor: TensorAny<R, T, B, D>, order: FlagOrder) -> Result<Tensor<T, B, D>>
1506where
1507 R: DataAPI<Data = B::Raw> + DataIntoCowAPI<'a>,
1508 D: DimAPI,
1509 T: Clone,
1510 B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, D, D> + OpAssignAPI<T, D>,
1511 B::Raw: Clone + 'a,
1512{
1513 change_prefer_f(tensor, order).map(|v| v.into_owned())
1514}
1515
1516pub fn change_prefer<'a, R, T, B, D>(tensor: TensorAny<R, T, B, D>, order: FlagOrder) -> TensorCow<'a, T, B, D>
1517where
1518 R: DataAPI<Data = B::Raw> + DataIntoCowAPI<'a>,
1519 D: DimAPI,
1520 B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, D, D>,
1521{
1522 change_prefer_f(tensor, order).unwrap()
1523}
1524
1525pub fn to_prefer<R, T, B, D>(tensor: &TensorAny<R, T, B, D>, order: FlagOrder) -> TensorCow<'_, T, B, D>
1526where
1527 R: DataAPI<Data = B::Raw>,
1528 D: DimAPI,
1529 B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, D, D>,
1530{
1531 to_prefer_f(tensor, order).unwrap()
1532}
1533
1534pub fn into_prefer<'a, R, T, B, D>(tensor: TensorAny<R, T, B, D>, order: FlagOrder) -> Tensor<T, B, D>
1535where
1536 R: DataAPI<Data = B::Raw> + DataIntoCowAPI<'a>,
1537 D: DimAPI,
1538 T: Clone,
1539 B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, D, D> + OpAssignAPI<T, D>,
1540 B::Raw: Clone + 'a,
1541{
1542 into_prefer_f(tensor, order).unwrap()
1543}
1544
1545impl<'a, R, T, B, D> TensorAny<R, T, B, D>
1546where
1547 R: DataAPI<Data = B::Raw> + DataIntoCowAPI<'a>,
1548 D: DimAPI,
1549 T: Clone,
1550 B: DeviceAPI<T> + DeviceCreationAnyAPI<T>,
1551{
1552 pub fn to_prefer(&self, order: FlagOrder) -> TensorCow<'_, T, B, D>
1554 where
1555 B: OpAssignArbitaryAPI<T, D, D>,
1556 {
1557 to_prefer(self, order)
1558 }
1559
1560 pub fn to_prefer_f(&self, order: FlagOrder) -> Result<TensorCow<'_, T, B, D>>
1561 where
1562 B: OpAssignArbitaryAPI<T, D, D>,
1563 {
1564 to_prefer_f(self, order)
1565 }
1566
1567 pub fn into_prefer_f(self, order: FlagOrder) -> Result<Tensor<T, B, D>>
1568 where
1569 B: OpAssignArbitaryAPI<T, D, D> + OpAssignAPI<T, D>,
1570 B::Raw: Clone + 'a,
1571 {
1572 into_prefer_f(self, order)
1573 }
1574
1575 pub fn into_prefer(self, order: FlagOrder) -> Tensor<T, B, D>
1576 where
1577 B: OpAssignArbitaryAPI<T, D, D> + OpAssignAPI<T, D>,
1578 B::Raw: Clone + 'a,
1579 {
1580 into_prefer(self, order)
1581 }
1582
1583 pub fn change_prefer_f(self, order: FlagOrder) -> Result<TensorCow<'a, T, B, D>>
1584 where
1585 B: OpAssignArbitaryAPI<T, D, D>,
1586 {
1587 change_prefer_f(self, order)
1588 }
1589
1590 pub fn change_prefer(self, order: FlagOrder) -> TensorCow<'a, T, B, D>
1591 where
1592 B: OpAssignArbitaryAPI<T, D, D>,
1593 {
1594 change_prefer(self, order)
1595 }
1596}
1597
#[cfg(test)]
mod test_reshape {
    use super::*;

    /// Exploratory test: prints reshape results and pointer comparisons.
    /// The only failure mode is one of the `unwrap` calls panicking.
    #[test]
    fn test_playground() {
        #[cfg(not(feature = "col_major"))]
        {
            let a1 = linspace((1.0, 24.0, 24));
            let default_order = a1.device().default_order();

            let a2 = a1.to_shape([2, 3, 4]);
            println!("{a2:?}");
            println!("{:?}", core::ptr::eq(a1.as_ptr(), a2.as_ptr()));

            let shape_3d = vec![2, 3, 4];
            let v = layout_reshapeable(a1.layout(), &shape_3d, default_order).unwrap();
            println!("{v:?}");

            // Same data, but stored with the layout produced by `.f()`.
            let b1 = linspace((1.0, 24.0, 24)).into_layout(vec![2, 3, 4].f());
            let b2 = b1.to_shape([24]);
            println!("{b2:?}");
            println!("{:?}", core::ptr::eq(b1.as_ptr(), b2.as_ptr()));

            let shape_1d = vec![24];
            let v = layout_reshapeable(b1.layout(), &shape_1d, default_order).unwrap();
            println!("{v:?}");
        }
        #[cfg(feature = "col_major")]
        {
            let a1 = linspace((1.0, 24.0, 24));
            let default_order = a1.device().default_order();

            let a2 = a1.to_shape([2, 3, 4]);
            println!("{a2:?}");
            println!("{:?}", core::ptr::eq(a1.as_ptr(), a2.as_ptr()));
            // Print every slice along the last axis.
            println!("a2[:, :, 0] =\n{:}", a2.i((.., .., 0)));
            println!("a2[:, :, 1] =\n{:}", a2.i((.., .., 1)));
            println!("a2[:, :, 2] =\n{:}", a2.i((.., .., 2)));
            println!("a2[:, :, 3] =\n{:}", a2.i((.., .., 3)));

            let shape_3d = vec![2, 3, 4];
            let v = layout_reshapeable(a1.layout(), &shape_3d, default_order).unwrap();
            println!("{v:?}");

            let b1 = linspace((1.0, 24.0, 24)).into_layout(vec![2, 3, 4].f());
            let b2 = b1.to_shape([24]);
            println!("{b2:?}");
            println!("{:?}", core::ptr::eq(b1.as_ptr(), b2.as_ptr()));

            let shape_1d = vec![24];
            let v = layout_reshapeable(b1.layout(), &shape_1d, default_order).unwrap();
            println!("{v:?}");
        }
    }

    /// A contiguous input layout must reshape to the contiguous layout of
    /// every target shape tried here.
    #[test]
    fn test_contig() {
        #[cfg(not(feature = "col_major"))]
        {
            let default_order = RowMajor;
            let layout_in = vec![2, 3, 4].c();

            for shape in [vec![2, 3, 4], vec![3, 2, 4], vec![1, 4, 1, 6]] {
                let layout_out = layout_reshapeable(&layout_in, &shape, default_order).unwrap();
                assert_eq!(layout_out.unwrap(), shape.c());
            }
        }
        #[cfg(feature = "col_major")]
        {
            let default_order = ColMajor;
            let layout_in = vec![2, 3, 4].f();

            for shape in [vec![2, 3, 4], vec![3, 2, 4], vec![1, 4, 1, 6]] {
                let layout_out = layout_reshapeable(&layout_in, &shape, default_order).unwrap();
                assert_eq!(layout_out.unwrap(), shape.f());
            }
        }
    }

    /// Strided, offset input layouts: checks the shape, stride and offset of
    /// the reshaped layout, and that an incompatible target yields `None`.
    #[test]
    fn test_partial_contig() {
        #[cfg(not(feature = "col_major"))]
        {
            let default_order = RowMajor;

            // First input: strides [270, 18, 3], offset 810.
            let layout_in = Layout::new(vec![9, 15, 6], vec![270, 18, 3], 810).unwrap();

            let got = layout_reshapeable(&layout_in, &vec![15, 9, 2, 3], default_order).unwrap().unwrap();
            assert_eq!(got.shape(), &vec![15, 9, 2, 3]);
            assert_eq!(got.stride(), &vec![162, 18, 9, 3]);
            assert_eq!(got.offset(), layout_in.offset());

            let got = layout_reshapeable(&layout_in, &vec![10, 27, 3], default_order).unwrap().unwrap();
            assert_eq!(got.shape(), &vec![10, 27, 3]);
            assert_eq!(got.stride(), &vec![243, 9, 3]);
            assert_eq!(got.offset(), layout_in.offset());

            let got = layout_reshapeable(&layout_in, &vec![1, 10, 1, 27, 3], default_order).unwrap().unwrap();
            assert_eq!(got.shape(), &vec![1, 10, 1, 27, 3]);
            assert_eq!(got.stride(), &vec![2430, 243, 243, 9, 3]);

            // Second input: last stride is 2 instead of 3, offset 813.
            let layout_in = Layout::new(vec![9, 15, 6], vec![270, 18, 2], 813).unwrap();

            let got = layout_reshapeable(&layout_in, &vec![15, 9, 2, 3], default_order).unwrap().unwrap();
            assert_eq!(got.shape(), &vec![15, 9, 2, 3]);
            assert_eq!(got.stride(), &vec![162, 18, 6, 2]);
            assert_eq!(got.offset(), layout_in.offset());

            let rejected = layout_reshapeable(&layout_in, &vec![10, 27, 3], default_order).unwrap();
            assert!(rejected.is_none());
        }
        #[cfg(feature = "col_major")]
        {
            let default_order = ColMajor;

            // First input: strides [3, 18, 270], offset 810.
            let layout_in = Layout::new(vec![6, 15, 9], vec![3, 18, 270], 810).unwrap();

            let got = layout_reshapeable(&layout_in, &vec![3, 2, 9, 15], default_order).unwrap().unwrap();
            assert_eq!(got.shape(), &vec![3, 2, 9, 15]);
            assert_eq!(got.stride(), &vec![3, 9, 18, 162]);
            assert_eq!(got.offset(), layout_in.offset());

            let got = layout_reshapeable(&layout_in, &vec![3, 27, 10], default_order).unwrap().unwrap();
            assert_eq!(got.shape(), &vec![3, 27, 10]);
            assert_eq!(got.stride(), &vec![3, 9, 243]);
            assert_eq!(got.offset(), layout_in.offset());

            let got = layout_reshapeable(&layout_in, &vec![3, 27, 1, 10, 1], default_order).unwrap().unwrap();
            assert_eq!(got.shape(), &vec![3, 27, 1, 10, 1]);
            assert_eq!(got.stride(), &vec![3, 9, 243, 243, 2430]);

            // Second input: first stride is 2 instead of 3, offset 813.
            let layout_in = Layout::new(vec![6, 15, 9], vec![2, 18, 270], 813).unwrap();

            let got = layout_reshapeable(&layout_in, &vec![3, 2, 9, 15], default_order).unwrap().unwrap();
            assert_eq!(got.shape(), &vec![3, 2, 9, 15]);
            assert_eq!(got.stride(), &vec![2, 6, 18, 162]);
            assert_eq!(got.offset(), layout_in.offset());

            let rejected = layout_reshapeable(&layout_in, &vec![10, 27, 3], default_order).unwrap();
            assert!(rejected.is_none());
        }
    }

    /// Input layout with negative strides (row-major configuration only).
    #[test]
    fn test_minus_stride() {
        #[cfg(not(feature = "col_major"))]
        {
            let default_order = RowMajor;
            let layout_in = Layout::new(vec![9, 15, 6], vec![270, -18, -3], 1079).unwrap();

            let rejected = layout_reshapeable(&layout_in, &vec![15, 9, 2, 3], default_order).unwrap();
            assert!(rejected.is_none());

            let got = layout_reshapeable(&layout_in, &vec![3, 3, 10, 9], default_order).unwrap().unwrap();
            assert_eq!(got.shape(), &vec![3, 3, 10, 9]);
            assert_eq!(got.stride(), &vec![810, 270, -27, -3]);
        }
    }

    /// Input layout containing a zero stride (row-major configuration only).
    #[test]
    fn test_broadcast_reshape() {
        #[cfg(not(feature = "col_major"))]
        {
            let default_order = RowMajor;
            // SAFETY: the layout is only passed to `layout_reshapeable` in this
            // test and is never attached to any storage here.
            // NOTE(review): presumably the unchecked constructor is needed because
            // of the zero stride — confirm against `Layout::new`.
            let layout_in = unsafe { Layout::new_unchecked(vec![12, 16, 15, 18], vec![270, 0, 18, 1], 0) };

            let got =
                layout_reshapeable(&layout_in, &vec![4, 3, 4, 4, 9, 1, 30], default_order).unwrap().unwrap();
            assert_eq!(got.shape(), &vec![4, 3, 4, 4, 9, 1, 30]);
            assert_eq!(got.stride(), &vec![810, 270, 0, 0, 30, 30, 1]);

            let rejected = layout_reshapeable(&layout_in, &vec![16, 12, 15, 18], default_order).unwrap();
            assert!(rejected.is_none());
        }
    }
}
1788
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Tensor;

    /// Smoke test: reshaping 16 elements to 4x4 via the "assume contiguous"
    /// path must succeed; the result is only printed.
    #[test]
    fn test_to_shape_assume_contig() {
        let flat = linspace((2.5, 3.2, 16));
        let b = flat.to_shape_assume_contig_f([4, 4]).unwrap();
        println!("{b:.3?}");
    }

    /// `expand_dims` with a single axis, several axes, and negative axes.
    #[test]
    fn test_expand_dims() {
        let a: Tensor<f64, _> = zeros([4, 9, 8]);

        assert_eq!(a.expand_dims(2).shape(), &[4, 9, 1, 8]);
        assert_eq!(a.expand_dims([1, 3]).shape(), &[4, 1, 9, 8, 1]);
        assert_eq!(a.expand_dims([1, -1]).shape(), &[4, 1, 9, 8, 1]);
        assert_eq!(a.expand_dims([-1, -4, 1, 0]).shape(), &[1, 1, 4, 1, 9, 8, 1]);
    }

    /// `squeeze` with a single axis, several axes, negative axes, and an
    /// axis for which the fallible variant must report an error.
    #[test]
    fn test_squeeze() {
        let a: Tensor<f64, _> = zeros([4, 1, 9, 1, 8, 1]);

        assert_eq!(a.squeeze(3).shape(), &[4, 1, 9, 8, 1]);
        assert_eq!(a.squeeze([1, 3]).shape(), &[4, 9, 8, 1]);
        assert_eq!(a.squeeze([1, -1]).shape(), &[4, 9, 1, 8]);
        assert!(a.squeeze_f(-7).is_err());
    }

    /// Flipping axes must leave the shape unchanged.
    #[test]
    fn test_flip() {
        let a = arange(24.0).into_shape([2, 3, 4]).into_owned();
        println!("{a:?}");

        // Single axis.
        let b = a.flip(1);
        println!("{b:?}");
        assert_eq!(b.shape(), &[2, 3, 4]);

        // Several axes, one of them negative.
        let c = a.flip([0, -1]);
        println!("{c:?}");
        assert_eq!(c.shape(), &[2, 3, 4]);
    }

    /// Swapping axes 0 and 1 of a [2, 3, 4] tensor gives shape [3, 2, 4].
    #[test]
    fn test_swapaxes() {
        let a = arange(24.0).into_shape([2, 3, 4]).into_owned();
        println!("{a:?}");

        let b = a.swapaxes(0, 1);
        println!("{b:?}");
        assert_eq!(b.shape(), &[3, 2, 4]);
    }

    /// `to_shape` on a tensor whose layout was replaced by hand, including a
    /// `-1` wildcard dimension and a wildcard request that must fail.
    #[test]
    fn test_to_shape() {
        let base = linspace((0.0, 15.0, 16));
        let mut a = base.to_shape([4, 4]);
        // Overwrite the layout with a strided 2x2 one over the same data.
        a.layout = Layout::new(vec![2, 2], vec![2, 4], 0).unwrap();
        println!("{a:?}");

        let b = a.to_shape([2, 2]);
        println!("{b:?}");

        let c = a.to_shape([2, -1]);
        println!("{c:?}");
        assert_eq!(c.shape(), &[2, 2]);

        // The fallible variant must reject this wildcard request.
        assert!(a.to_shape_f([3, -1]).is_err());
    }

    /// Broadcasting a [4, 1, 4] tensor and checking the resulting layout,
    /// which contains zero strides.
    #[test]
    fn test_broadcast_to() {
        #[cfg(not(feature = "col_major"))]
        {
            let flat = linspace((0.0, 15.0, 16));
            let shaped = flat.into_shape_assume_contig_f([4, 1, 4]).unwrap();
            let a = shaped.to_broadcast_f([6, 4, 3, 4]).unwrap();
            println!("{a:?}");
            // SAFETY: `expected` is only compared against `a.layout()`; it is
            // never attached to any storage.
            let expected = unsafe { Layout::new_unchecked([6, 4, 3, 4], [0, 4, 0, 1], 0) };
            assert_eq!(a.layout(), &expected);
        }
        #[cfg(feature = "col_major")]
        {
            let flat = linspace((0.0, 15.0, 16));
            let shaped = flat.into_shape_assume_contig_f([4, 1, 4]).unwrap();
            let a = shaped.to_broadcast_f([4, 3, 4, 6]).unwrap();
            println!("{a:?}");
            // SAFETY: `expected` is only compared against `a.layout()`; it is
            // never attached to any storage.
            let expected = unsafe { Layout::new_unchecked([4, 3, 4, 6], [1, 0, 4, 0], 0) };
            assert_eq!(a.layout(), &expected);
        }
    }

    /// Smoke test: converting a 4x4 tensor into a strided, offset [2, 8]
    /// layout must not panic; the result is only printed.
    #[test]
    fn test_to_layout() {
        let flat = linspace((0.0, 15.0, 16));
        let square = flat.change_shape([4, 4]);
        let target = Layout::new([2, 8], [12, 120], 8).unwrap();
        let a = square.into_layout(target);
        println!("{a:?}");
    }
}