rstsr_core/tensor/
manuplication.rs

1//! This module handles tensor data manipulation.
2
3use crate::prelude_dev::*;
4
5/* #region broadcast_arrays */
6
7/// Broadcasts any number of arrays against each other.
8///
9/// # See also
10///
11/// [Python Array API standard: `broadcast_arrays`](https://data-apis.org/array-api/2023.12/API_specification/generated/array_api.broadcast_arrays.html)
12pub fn broadcast_arrays<R, T, B>(tensors: Vec<TensorAny<R, T, B, IxD>>) -> Vec<TensorAny<R, T, B, IxD>>
13where
14    R: DataAPI<Data = B::Raw>,
15    B: DeviceAPI<T>,
16{
17    broadcast_arrays_f(tensors).unwrap()
18}
19
20pub fn broadcast_arrays_f<R, T, B>(tensors: Vec<TensorAny<R, T, B, IxD>>) -> Result<Vec<TensorAny<R, T, B, IxD>>>
21where
22    R: DataAPI<Data = B::Raw>,
23    B: DeviceAPI<T>,
24{
25    // fast return if there is only zero/one tensor
26    if tensors.len() <= 1 {
27        return Ok(tensors);
28    }
29    let device_b = tensors[0].device().clone();
30    let default_order = device_b.default_order();
31    let mut shape_b = tensors[0].shape().clone();
32    for tensor in tensors.iter().skip(1) {
33        rstsr_assert!(device_b.same_device(tensor.device()), DeviceMismatch)?;
34        let shape = tensor.shape();
35        let (shape, _, _) = broadcast_shape(shape, &shape_b, default_order)?;
36        shape_b = shape;
37    }
38    let mut tensors_new = Vec::with_capacity(tensors.len());
39    for tensor in tensors {
40        let tensor = into_broadcast_f(tensor, shape_b.clone())?;
41        tensors_new.push(tensor);
42    }
43    return Ok(tensors_new);
44}
45
46/* #endregion */
47
48/* #region broadcast_to */
49
50pub fn into_broadcast_f<R, T, B, D, D2>(tensor: TensorAny<R, T, B, D>, shape: D2) -> Result<TensorAny<R, T, B, D2>>
51where
52    R: DataAPI<Data = B::Raw>,
53    B: DeviceAPI<T>,
54    D: DimAPI + DimMaxAPI<D2, Max = D2>,
55    D2: DimAPI,
56{
57    let shape1 = tensor.shape();
58    let shape2 = &shape;
59    let default_order = tensor.device().default_order();
60    let (shape, tp1, _) = broadcast_shape(shape1, shape2, default_order)?;
61    let (storage, layout) = tensor.into_raw_parts();
62    let layout = update_layout_by_shape(&layout, &shape, &tp1, default_order)?;
63    unsafe { Ok(TensorBase::new_unchecked(storage, layout)) }
64}
65
66/// Broadcasts an array to a specified shape.
67///
68/// # See also
69///
70/// [Python Array API standard: `broadcast_to`](https://data-apis.org/array-api/2023.12/API_specification/generated/array_api.broadcast_to.html)
71pub fn to_broadcast<R, T, B, D, D2>(tensor: &TensorAny<R, T, B, D>, shape: D2) -> TensorView<'_, T, B, D2>
72where
73    D: DimAPI + DimMaxAPI<D2, Max = D2>,
74    D2: DimAPI,
75    R: DataAPI<Data = B::Raw>,
76    B: DeviceAPI<T>,
77{
78    into_broadcast_f(tensor.view(), shape).unwrap()
79}
80
81pub fn to_broadcast_f<R, T, B, D, D2>(tensor: &TensorAny<R, T, B, D>, shape: D2) -> Result<TensorView<'_, T, B, D2>>
82where
83    D: DimAPI + DimMaxAPI<D2, Max = D2>,
84    D2: DimAPI,
85    R: DataAPI<Data = B::Raw>,
86    B: DeviceAPI<T>,
87{
88    into_broadcast_f(tensor.view(), shape)
89}
90
91pub fn into_broadcast<R, T, B, D, D2>(tensor: TensorAny<R, T, B, D>, shape: D2) -> TensorAny<R, T, B, D2>
92where
93    R: DataAPI<Data = B::Raw>,
94    B: DeviceAPI<T>,
95    D: DimAPI + DimMaxAPI<D2, Max = D2>,
96    D2: DimAPI,
97{
98    into_broadcast_f(tensor, shape).unwrap()
99}
100
101impl<R, T, B, D> TensorAny<R, T, B, D>
102where
103    R: DataAPI<Data = B::Raw>,
104    B: DeviceAPI<T>,
105    D: DimAPI,
106{
107    /// Broadcasts an array to a specified shape.
108    ///
109    /// # See also
110    ///
111    /// [`to_broadcast`]
112    pub fn to_broadcast<D2>(&self, shape: D2) -> TensorView<'_, T, B, D2>
113    where
114        D2: DimAPI,
115        D: DimMaxAPI<D2, Max = D2>,
116    {
117        to_broadcast(self, shape)
118    }
119
120    pub fn to_broadcast_f<D2>(&self, shape: D2) -> Result<TensorView<'_, T, B, D2>>
121    where
122        D2: DimAPI,
123        D: DimMaxAPI<D2, Max = D2>,
124    {
125        to_broadcast_f(self, shape)
126    }
127
128    /// Broadcasts an array to a specified shape.
129    ///
130    /// # See also
131    ///
132    /// [`to_broadcast`]
133    pub fn into_broadcast<D2>(self, shape: D2) -> TensorAny<R, T, B, D2>
134    where
135        D2: DimAPI,
136        D: DimMaxAPI<D2, Max = D2>,
137    {
138        into_broadcast(self, shape)
139    }
140
141    pub fn into_broadcast_f<D2>(self, shape: D2) -> Result<TensorAny<R, T, B, D2>>
142    where
143        D2: DimAPI,
144        D: DimMaxAPI<D2, Max = D2>,
145    {
146        into_broadcast_f(self, shape)
147    }
148}
149
150/* #endregion */
151
152/* #region expand_dims */
153
154pub fn into_expand_dims_f<I, S, D>(tensor: TensorBase<S, D>, axes: I) -> Result<TensorBase<S, IxD>>
155where
156    D: DimAPI,
157    I: TryInto<AxesIndex<isize>, Error = Error>,
158{
159    // convert axis to negative indexes and sort
160    let ndim: isize = TryInto::<isize>::try_into(tensor.ndim())?;
161    let (storage, layout) = tensor.into_raw_parts();
162    let mut layout = layout.into_dim::<IxD>()?;
163    let mut axes: Vec<isize> =
164        axes.try_into()?.as_ref().iter().map(|&v| if v >= 0 { v - ndim - 1 } else { v }).collect::<Vec<isize>>();
165    axes.sort();
166    for &axis in axes.iter() {
167        layout = layout.dim_insert(axis)?;
168    }
169    unsafe { Ok(TensorBase::new_unchecked(storage, layout)) }
170}
171
172/// Expands the shape of an array by inserting a new axis (dimension) of size
173/// one at the position specified by `axis`.
174///
175/// # Panics
176///
177/// - If `axis` is greater than the number of axes in the original tensor.
178///
179/// # See also
180///
181/// [Python Array API standard: `expand_dims`](https://data-apis.org/array-api/2023.12/API_specification/generated/array_api.expand_dims.html)
182pub fn expand_dims<I, R, T, B, D>(tensor: &TensorAny<R, T, B, D>, axes: I) -> TensorView<'_, T, B, IxD>
183where
184    D: DimAPI,
185    I: TryInto<AxesIndex<isize>, Error = Error>,
186    R: DataAPI<Data = B::Raw>,
187    B: DeviceAPI<T>,
188{
189    into_expand_dims_f(tensor.view(), axes).unwrap()
190}
191
192pub fn expand_dims_f<I, R, T, B, D>(tensor: &TensorAny<R, T, B, D>, axes: I) -> Result<TensorView<'_, T, B, IxD>>
193where
194    D: DimAPI,
195    I: TryInto<AxesIndex<isize>, Error = Error>,
196    R: DataAPI<Data = B::Raw>,
197    B: DeviceAPI<T>,
198{
199    into_expand_dims_f(tensor.view(), axes)
200}
201
202pub fn into_expand_dims<I, S, D>(tensor: TensorBase<S, D>, axes: I) -> TensorBase<S, IxD>
203where
204    D: DimAPI,
205    I: TryInto<AxesIndex<isize>, Error = Error>,
206{
207    into_expand_dims_f(tensor, axes).unwrap()
208}
209
210impl<R, T, B, D> TensorAny<R, T, B, D>
211where
212    R: DataAPI<Data = B::Raw>,
213    B: DeviceAPI<T>,
214    D: DimAPI,
215{
216    /// Expands the shape of an array by inserting a new axis (dimension) of
217    /// size one at the position specified by `axis`.
218    ///
219    /// # See also
220    ///
221    /// [`expand_dims`]
222    pub fn expand_dims<I>(&self, axes: I) -> TensorView<'_, T, B, IxD>
223    where
224        I: TryInto<AxesIndex<isize>, Error = Error>,
225    {
226        into_expand_dims(self.view(), axes)
227    }
228
229    pub fn expand_dims_f<I>(&self, axes: I) -> Result<TensorView<'_, T, B, IxD>>
230    where
231        I: TryInto<AxesIndex<isize>, Error = Error>,
232    {
233        into_expand_dims_f(self.view(), axes)
234    }
235
236    /// Expands the shape of an array by inserting a new axis (dimension) of
237    /// size one at the position specified by `axis`.
238    ///
239    /// # See also
240    ///
241    /// [`expand_dims`]
242    pub fn into_expand_dims<I>(self, axes: I) -> TensorAny<R, T, B, IxD>
243    where
244        I: TryInto<AxesIndex<isize>, Error = Error>,
245    {
246        into_expand_dims(self, axes)
247    }
248
249    pub fn into_expand_dims_f<I>(self, axes: I) -> Result<TensorAny<R, T, B, IxD>>
250    where
251        I: TryInto<AxesIndex<isize>, Error = Error>,
252    {
253        into_expand_dims_f(self, axes)
254    }
255}
256
257/* #endregion */
258
259/* #region flip */
260
261pub fn into_flip_f<I, S, D>(tensor: TensorBase<S, D>, axes: I) -> Result<TensorBase<S, D>>
262where
263    D: DimAPI,
264    I: TryInto<AxesIndex<isize>, Error = Error>,
265{
266    let (storage, mut layout) = tensor.into_raw_parts();
267    let axes = axes.try_into()?;
268    match axes {
269        AxesIndex::Val(axis) => {
270            layout = layout.dim_narrow(axis, slice!(None, None, -1))?;
271        },
272        AxesIndex::Vec(axes) => {
273            for &axis in axes.iter() {
274                layout = layout.dim_narrow(axis, slice!(None, None, -1))?;
275            }
276        },
277    }
278    unsafe { Ok(TensorBase::new_unchecked(storage, layout)) }
279}
280
281/// Reverses the order of elements in an array along the given axis.
282///
283/// # Panics
284///
285/// - If some index in `axis` is greater than the number of axes in the original tensor.
286///
287/// # See also
288///
289/// [Python array API standard: `flip`](https://data-apis.org/array-api/2023.12/API_specification/generated/array_api.flip.html)
290pub fn flip<I, R, T, B, D>(tensor: &TensorAny<R, T, B, D>, axes: I) -> TensorView<'_, T, B, D>
291where
292    D: DimAPI,
293    I: TryInto<AxesIndex<isize>, Error = Error>,
294    R: DataAPI<Data = B::Raw>,
295    B: DeviceAPI<T>,
296{
297    into_flip_f(tensor.view(), axes).unwrap()
298}
299
300pub fn flip_f<I, R, T, B, D>(tensor: &TensorAny<R, T, B, D>, axes: I) -> Result<TensorView<'_, T, B, D>>
301where
302    D: DimAPI,
303    I: TryInto<AxesIndex<isize>, Error = Error>,
304    R: DataAPI<Data = B::Raw>,
305    B: DeviceAPI<T>,
306{
307    into_flip_f(tensor.view(), axes)
308}
309
310pub fn into_flip<I, S, D>(tensor: TensorBase<S, D>, axes: I) -> TensorBase<S, D>
311where
312    D: DimAPI,
313    I: TryInto<AxesIndex<isize>, Error = Error>,
314{
315    into_flip_f(tensor, axes).unwrap()
316}
317
318impl<R, T, B, D> TensorAny<R, T, B, D>
319where
320    R: DataAPI<Data = B::Raw>,
321    B: DeviceAPI<T>,
322    D: DimAPI,
323{
324    /// Reverses the order of elements in an array along the given axis.
325    ///
326    /// # See also
327    ///
328    /// [`flip`]
329    pub fn flip<I>(&self, axis: I) -> TensorView<'_, T, B, D>
330    where
331        I: TryInto<AxesIndex<isize>, Error = Error>,
332    {
333        flip(self, axis)
334    }
335
336    pub fn flip_f<I>(&self, axis: I) -> Result<TensorView<'_, T, B, D>>
337    where
338        I: TryInto<AxesIndex<isize>, Error = Error>,
339    {
340        flip_f(self, axis)
341    }
342
343    /// Reverses the order of elements in an array along the given axis.
344    ///
345    /// # See also
346    ///
347    /// [`flip`]
348    pub fn into_flip<I>(self, axis: I) -> TensorAny<R, T, B, D>
349    where
350        I: TryInto<AxesIndex<isize>, Error = Error>,
351    {
352        into_flip(self, axis)
353    }
354
355    pub fn into_flip_f<I>(self, axis: I) -> Result<TensorAny<R, T, B, D>>
356    where
357        I: TryInto<AxesIndex<isize>, Error = Error>,
358    {
359        into_flip_f(self, axis)
360    }
361}
362
363/* #endregion */
364
365/* #region permute_dims */
366
367pub fn into_transpose_f<I, S, D>(tensor: TensorBase<S, D>, axes: I) -> Result<TensorBase<S, D>>
368where
369    D: DimAPI,
370    I: TryInto<AxesIndex<isize>, Error = Error>,
371{
372    let axes = axes.try_into()?;
373    if axes.as_ref().is_empty() {
374        return Ok(into_reverse_axes(tensor));
375    }
376    let (storage, layout) = tensor.into_raw_parts();
377    let layout = layout.transpose(axes.as_ref())?;
378    unsafe { Ok(TensorBase::new_unchecked(storage, layout)) }
379}
380
381/// Permutes the axes (dimensions) of an array `x`.
382///
383/// # See also
384///
385/// - [Python array API standard: `permute_dims`](https://data-apis.org/array-api/2023.12/API_specification/generated/array_api.permute_dims.html)
386pub fn transpose<I, R, T, B, D>(tensor: &TensorAny<R, T, B, D>, axes: I) -> TensorView<'_, T, B, D>
387where
388    D: DimAPI,
389    I: TryInto<AxesIndex<isize>, Error = Error>,
390    R: DataAPI<Data = B::Raw>,
391    B: DeviceAPI<T>,
392{
393    into_transpose_f(tensor.view(), axes).unwrap()
394}
395
396pub fn transpose_f<I, R, T, B, D>(tensor: &TensorAny<R, T, B, D>, axes: I) -> Result<TensorView<'_, T, B, D>>
397where
398    D: DimAPI,
399    I: TryInto<AxesIndex<isize>, Error = Error>,
400    R: DataAPI<Data = B::Raw>,
401    B: DeviceAPI<T>,
402{
403    into_transpose_f(tensor.view(), axes)
404}
405
406pub fn into_transpose<I, S, D>(tensor: TensorBase<S, D>, axes: I) -> TensorBase<S, D>
407where
408    D: DimAPI,
409    I: TryInto<AxesIndex<isize>, Error = Error>,
410{
411    into_transpose_f(tensor, axes).unwrap()
412}
413
414pub use into_transpose as into_permute_dims;
415pub use into_transpose_f as into_permute_dims_f;
416pub use transpose as permute_dims;
417pub use transpose_f as permute_dims_f;
418
419impl<R, T, B, D> TensorAny<R, T, B, D>
420where
421    R: DataAPI<Data = B::Raw>,
422    B: DeviceAPI<T>,
423    D: DimAPI,
424{
425    /// Permutes the axes (dimensions) of an array `x`.
426    ///
427    /// # See also
428    ///
429    /// [`transpose`]
430    pub fn transpose<I>(&self, axes: I) -> TensorView<'_, T, B, D>
431    where
432        I: TryInto<AxesIndex<isize>, Error = Error>,
433    {
434        transpose(self, axes)
435    }
436
437    pub fn transpose_f<I>(&self, axes: I) -> Result<TensorView<'_, T, B, D>>
438    where
439        I: TryInto<AxesIndex<isize>, Error = Error>,
440    {
441        transpose_f(self, axes)
442    }
443
444    /// Permutes the axes (dimensions) of an array `x`.
445    ///
446    /// # See also
447    ///
448    /// [`transpose`]
449    pub fn into_transpose<I>(self, axes: I) -> TensorAny<R, T, B, D>
450    where
451        I: TryInto<AxesIndex<isize>, Error = Error>,
452    {
453        into_transpose(self, axes)
454    }
455
456    pub fn into_transpose_f<I>(self, axes: I) -> Result<TensorAny<R, T, B, D>>
457    where
458        I: TryInto<AxesIndex<isize>, Error = Error>,
459    {
460        into_transpose_f(self, axes)
461    }
462
463    /// Permutes the axes (dimensions) of an array `x`.
464    ///
465    /// # See also
466    ///
467    /// [`transpose`]
468    pub fn permute_dims<I>(&self, axes: I) -> TensorView<'_, T, B, D>
469    where
470        I: TryInto<AxesIndex<isize>, Error = Error>,
471    {
472        transpose(self, axes)
473    }
474
475    pub fn permute_dims_f<I>(&self, axes: I) -> Result<TensorView<'_, T, B, D>>
476    where
477        I: TryInto<AxesIndex<isize>, Error = Error>,
478    {
479        transpose_f(self, axes)
480    }
481
482    /// Permutes the axes (dimensions) of an array `x`.
483    ///
484    /// # See also
485    ///
486    /// [`transpose`]
487    pub fn into_permute_dims<I>(self, axes: I) -> TensorAny<R, T, B, D>
488    where
489        I: TryInto<AxesIndex<isize>, Error = Error>,
490    {
491        into_transpose(self, axes)
492    }
493
494    pub fn into_permute_dims_f<I>(self, axes: I) -> Result<TensorAny<R, T, B, D>>
495    where
496        I: TryInto<AxesIndex<isize>, Error = Error>,
497    {
498        into_transpose_f(self, axes)
499    }
500}
501
502/* #endregion */
503
504/* #region reverse_axes */
505
506pub fn into_reverse_axes<S, D>(tensor: TensorBase<S, D>) -> TensorBase<S, D>
507where
508    D: DimAPI,
509{
510    let (storage, layout) = tensor.into_raw_parts();
511    let layout = layout.reverse_axes();
512    unsafe { TensorBase::new_unchecked(storage, layout) }
513}
514
515/// Reverse the order of elements in an array along the given axis.
516pub fn reverse_axes<R, T, B, D>(tensor: &TensorAny<R, T, B, D>) -> TensorView<'_, T, B, D>
517where
518    D: DimAPI,
519    R: DataAPI<Data = B::Raw>,
520    B: DeviceAPI<T>,
521{
522    into_reverse_axes(tensor.view())
523}
524
525impl<R, T, B, D> TensorAny<R, T, B, D>
526where
527    R: DataAPI<Data = B::Raw>,
528    B: DeviceAPI<T>,
529    D: DimAPI,
530{
531    /// Reverse the order of elements in an array along the given axis.
532    ///
533    /// # See also
534    ///
535    /// [`reverse_axes`]
536    pub fn reverse_axes(&self) -> TensorView<'_, T, B, D> {
537        into_reverse_axes(self.view())
538    }
539
540    /// Reverse the order of elements in an array along the given axis.
541    ///
542    /// # See also
543    ///
544    /// [`reverse_axes`]
545    pub fn into_reverse_axes(self) -> TensorAny<R, T, B, D> {
546        into_reverse_axes(self)
547    }
548
549    /// Reverse the order of elements in an array along the given axis.
550    ///
551    /// # See also
552    ///
553    /// [`reverse_axes`]
554    pub fn t(&self) -> TensorView<'_, T, B, D> {
555        into_reverse_axes(self.view())
556    }
557}
558
559/* #endregion */
560
561/* #region swapaxes */
562
563pub fn into_swapaxes_f<I, S, D>(tensor: TensorBase<S, D>, axis1: I, axis2: I) -> Result<TensorBase<S, D>>
564where
565    D: DimAPI,
566    I: TryInto<isize>,
567{
568    let axis1 = axis1.try_into().map_err(|_| rstsr_error!(TryFromIntError))?;
569    let axis2 = axis2.try_into().map_err(|_| rstsr_error!(TryFromIntError))?;
570    let (storage, layout) = tensor.into_raw_parts();
571    let layout = layout.swapaxes(axis1, axis2)?;
572    unsafe { Ok(TensorBase::new_unchecked(storage, layout)) }
573}
574
575/// Interchange two axes of an array.
576///
577/// # See also
578///
579/// - [numpy `swapaxes`](https://numpy.org/doc/stable/reference/generated/numpy.swapaxes.html)
580pub fn swapaxes<I, R, T, B, D>(tensor: &TensorAny<R, T, B, D>, axis1: I, axis2: I) -> TensorView<'_, T, B, D>
581where
582    D: DimAPI,
583    I: TryInto<isize>,
584    R: DataAPI<Data = B::Raw>,
585    B: DeviceAPI<T>,
586{
587    into_swapaxes_f(tensor.view(), axis1, axis2).unwrap()
588}
589
590pub fn swapaxes_f<I, R, T, B, D>(tensor: &TensorAny<R, T, B, D>, axis1: I, axis2: I) -> Result<TensorView<'_, T, B, D>>
591where
592    D: DimAPI,
593    I: TryInto<isize>,
594    R: DataAPI<Data = B::Raw>,
595    B: DeviceAPI<T>,
596{
597    into_swapaxes_f(tensor.view(), axis1, axis2)
598}
599
600pub fn into_swapaxes<I, S, D>(tensor: TensorBase<S, D>, axis1: I, axis2: I) -> TensorBase<S, D>
601where
602    D: DimAPI,
603    I: TryInto<isize>,
604{
605    into_swapaxes_f(tensor, axis1, axis2).unwrap()
606}
607
608impl<R, T, B, D> TensorAny<R, T, B, D>
609where
610    R: DataAPI<Data = B::Raw>,
611    B: DeviceAPI<T>,
612    D: DimAPI,
613{
614    /// Interchange two axes of an array.
615    ///
616    /// # See also
617    ///
618    /// [`swapaxes`]
619    pub fn swapaxes<I>(&self, axis1: I, axis2: I) -> TensorView<'_, T, B, D>
620    where
621        I: TryInto<isize>,
622    {
623        swapaxes(self, axis1, axis2)
624    }
625
626    pub fn swapaxes_f<I>(&self, axis1: I, axis2: I) -> Result<TensorView<'_, T, B, D>>
627    where
628        I: TryInto<isize>,
629    {
630        swapaxes_f(self, axis1, axis2)
631    }
632
633    /// Interchange two axes of an array.
634    ///
635    /// # See also
636    ///
637    /// [`swapaxes`]
638    pub fn into_swapaxes<I>(self, axis1: I, axis2: I) -> TensorAny<R, T, B, D>
639    where
640        I: TryInto<isize>,
641    {
642        into_swapaxes(self, axis1, axis2)
643    }
644
645    pub fn into_swapaxes_f<I>(self, axis1: I, axis2: I) -> Result<TensorAny<R, T, B, D>>
646    where
647        I: TryInto<isize>,
648    {
649        into_swapaxes_f(self, axis1, axis2)
650    }
651}
652
653/* #endregion */
654
655/* #region squeeze */
656
657pub fn into_squeeze_f<I, S, D>(tensor: TensorBase<S, D>, axes: I) -> Result<TensorBase<S, IxD>>
658where
659    D: DimAPI,
660    I: TryInto<AxesIndex<isize>, Error = Error>,
661{
662    // convert axis to positive indexes and (reversed) sort
663    let ndim: isize = TryInto::<isize>::try_into(tensor.ndim())?;
664    let (storage, layout) = tensor.into_raw_parts();
665    let mut layout = layout.into_dim::<IxD>()?;
666    let mut axes: Vec<isize> =
667        axes.try_into()?.as_ref().iter().map(|&v| if v >= 0 { v } else { v + ndim }).collect::<_>();
668    axes.sort_by(|a, b| b.cmp(a));
669    if axes.first().is_some_and(|&v| v < 0) {
670        return Err(rstsr_error!(InvalidValue, "Some negative index is too small."));
671    }
672    // check no two axis are the same
673    for i in 0..axes.len() - 1 {
674        rstsr_assert!(axes[i] != axes[i + 1], InvalidValue, "Same axes is not allowed here.")?;
675    }
676    // perform squeeze
677    for &axis in axes.iter() {
678        layout = layout.dim_eliminate(axis)?;
679    }
680    unsafe { Ok(TensorBase::new_unchecked(storage, layout)) }
681}
682
683/// Removes singleton dimensions (axes) from `x`.
684///
685/// # See also
686///
687/// [Python array API standard: `squeeze`](https://data-apis.org/array-api/2023.12/API_specification/generated/array_api.squeeze.html)
688pub fn squeeze<I, R, T, B, D>(tensor: &TensorAny<R, T, B, D>, axes: I) -> TensorView<'_, T, B, IxD>
689where
690    D: DimAPI,
691    I: TryInto<AxesIndex<isize>, Error = Error>,
692    R: DataAPI<Data = B::Raw>,
693    B: DeviceAPI<T>,
694{
695    into_squeeze_f(tensor.view(), axes).unwrap()
696}
697
698pub fn squeeze_f<I, R, T, B, D>(tensor: &TensorAny<R, T, B, D>, axes: I) -> Result<TensorView<'_, T, B, IxD>>
699where
700    D: DimAPI,
701    I: TryInto<AxesIndex<isize>, Error = Error>,
702    R: DataAPI<Data = B::Raw>,
703    B: DeviceAPI<T>,
704{
705    into_squeeze_f(tensor.view(), axes)
706}
707
708pub fn into_squeeze<I, S, D>(tensor: TensorBase<S, D>, axes: I) -> TensorBase<S, IxD>
709where
710    D: DimAPI,
711    I: TryInto<AxesIndex<isize>, Error = Error>,
712{
713    into_squeeze_f(tensor, axes).unwrap()
714}
715
716impl<R, T, B, D> TensorAny<R, T, B, D>
717where
718    R: DataAPI<Data = B::Raw>,
719    B: DeviceAPI<T>,
720    D: DimAPI,
721{
722    /// Removes singleton dimensions (axes) from `x`.
723    ///
724    /// # See also
725    ///
726    /// [`squeeze`]
727    pub fn squeeze<I>(&self, axis: I) -> TensorView<'_, T, B, IxD>
728    where
729        I: TryInto<AxesIndex<isize>, Error = Error>,
730    {
731        squeeze(self, axis)
732    }
733
734    pub fn squeeze_f<I>(&self, axis: I) -> Result<TensorView<'_, T, B, IxD>>
735    where
736        I: TryInto<AxesIndex<isize>, Error = Error>,
737    {
738        squeeze_f(self, axis)
739    }
740
741    /// Removes singleton dimensions (axes) from `x`.
742    ///
743    /// # See also
744    ///
745    /// [`squeeze`]
746    pub fn into_squeeze<I>(self, axis: I) -> TensorAny<R, T, B, IxD>
747    where
748        I: TryInto<AxesIndex<isize>, Error = Error>,
749    {
750        into_squeeze(self, axis)
751    }
752
753    pub fn into_squeeze_f<I>(self, axis: I) -> Result<TensorAny<R, T, B, IxD>>
754    where
755        I: TryInto<AxesIndex<isize>, Error = Error>,
756    {
757        into_squeeze_f(self, axis)
758    }
759}
760
761/* #endregion */
762
763/* #region into_dim */
764
765pub fn into_dim_f<S, D, D2>(tensor: TensorBase<S, D>) -> Result<TensorBase<S, D2>>
766where
767    D: DimAPI + DimIntoAPI<D2>,
768    D2: DimAPI,
769{
770    let (storage, layout) = tensor.into_raw_parts();
771    let layout = layout.into_dim::<D2>()?;
772    unsafe { Ok(TensorBase::new_unchecked(storage, layout)) }
773}
774
775/// Convert layout to the other dimension.
776///
777/// This is mostly used when converting static dimension to dynamic
778/// dimension or vice versa.
779pub fn to_dim<R, T, B, D, D2>(tensor: &TensorAny<R, T, B, D>) -> TensorView<'_, T, B, D2>
780where
781    D: DimAPI,
782    D2: DimAPI,
783    D: DimIntoAPI<D2>,
784    R: DataAPI<Data = B::Raw>,
785    B: DeviceAPI<T>,
786{
787    into_dim_f(tensor.view()).unwrap()
788}
789
790pub fn to_dim_f<R, T, B, D, D2>(tensor: &TensorAny<R, T, B, D>) -> Result<TensorView<'_, T, B, D2>>
791where
792    D: DimAPI,
793    D2: DimAPI,
794    D: DimIntoAPI<D2>,
795    R: DataAPI<Data = B::Raw>,
796    B: DeviceAPI<T>,
797{
798    into_dim_f(tensor.view())
799}
800
801pub fn into_dim<S, D, D2>(tensor: TensorBase<S, D>) -> TensorBase<S, D2>
802where
803    D: DimAPI,
804    D2: DimAPI,
805    D: DimIntoAPI<D2>,
806{
807    into_dim_f(tensor).unwrap()
808}
809
810pub fn to_dyn<R, T, B, D>(tensor: &TensorAny<R, T, B, D>) -> TensorView<'_, T, B, IxD>
811where
812    D: DimAPI,
813    R: DataAPI<Data = B::Raw>,
814    B: DeviceAPI<T>,
815{
816    into_dim_f(tensor.view()).unwrap()
817}
818
819pub fn into_dyn<S, D>(tensor: TensorBase<S, D>) -> TensorBase<S, IxD>
820where
821    D: DimAPI,
822{
823    into_dim_f(tensor).unwrap()
824}
825
826impl<R, T, B, D> TensorAny<R, T, B, D>
827where
828    D: DimAPI,
829    R: DataAPI<Data = B::Raw>,
830    B: DeviceAPI<T>,
831{
832    /// Convert layout to the other dimension.
833    ///
834    /// This is mostly used when converting static dimension to dynamic
835    /// dimension or vice versa.
836    ///
837    /// # See also
838    ///
839    /// [`into_dim`]
840    pub fn to_dim<D2>(&self) -> TensorView<'_, T, B, D2>
841    where
842        D2: DimAPI,
843        D: DimIntoAPI<D2>,
844    {
845        to_dim(self)
846    }
847
848    pub fn to_dim_f<D2>(&self) -> Result<TensorView<'_, T, B, D2>>
849    where
850        D2: DimAPI,
851        D: DimIntoAPI<D2>,
852    {
853        to_dim_f(self)
854    }
855
856    /// Convert layout to another dimension.
857    ///
858    /// # See also
859    ///
860    /// [`into_dim`]
861    pub fn into_dim<D2>(self) -> TensorAny<R, T, B, D2>
862    where
863        D2: DimAPI,
864        D: DimIntoAPI<D2>,
865    {
866        into_dim(self)
867    }
868
869    pub fn into_dim_f<D2>(self) -> Result<TensorAny<R, T, B, D2>>
870    where
871        D2: DimAPI,
872        D: DimIntoAPI<D2>,
873    {
874        into_dim_f(self)
875    }
876
877    /// Convert layout to dynamic dimension.
878    pub fn to_dyn(&self) -> TensorView<'_, T, B, IxD> {
879        to_dyn(self)
880    }
881
882    /// Convert layout to dynamic dimension.
883    pub fn into_dyn(self) -> TensorAny<R, T, B, IxD> {
884        into_dyn(self)
885    }
886}
887
888/* #endregion */
889
890/* #region reshape_assume_contig */
891
892pub fn into_shape_assume_contig_f<R, T, B, D, D2>(
893    tensor: TensorAny<R, T, B, D>,
894    shape: D2,
895) -> Result<TensorAny<R, T, B, D2>>
896where
897    R: DataAPI<Data = B::Raw>,
898    B: DeviceAPI<T>,
899    D: DimAPI,
900    D2: DimAPI,
901{
902    let default_order = tensor.device().default_order();
903    let (storage, layout) = tensor.into_raw_parts();
904
905    rstsr_assert_eq!(layout.size(), shape.shape_size(), InvalidLayout, "Number of elements not same.")?;
906
907    let new_layout = {
908        if default_order == FlagOrder::C && layout.c_contig() {
909            shape.new_c_contig(Some(layout.offset()))
910        } else if default_order == FlagOrder::F && layout.f_contig() {
911            shape.new_f_contig(Some(layout.offset()))
912        } else {
913            rstsr_raise!(InvalidLayout, "This array is not contiguous by {:?}", default_order)?
914        }
915    };
916    unsafe { Ok(TensorBase::new_unchecked(storage, new_layout)) }
917}
918
919/// Assuming contiguous array, reshapes an array without changing its data.
920///
921/// This function may return c-contiguous or f-contiguous array depending on
922/// crate feature `f_prefer`.
923///
924/// # See also
925///
926/// [Python array API standard: `reshape`](https://data-apis.org/array-api/2023.12/API_specification/generated/array_api.reshape.html)
927pub fn to_shape_assume_contig<R, T, B, D, D2>(tensor: &TensorAny<R, T, B, D>, shape: D2) -> TensorView<'_, T, B, D2>
928where
929    D: DimAPI,
930    D2: DimAPI,
931    R: DataAPI<Data = B::Raw>,
932    B: DeviceAPI<T>,
933{
934    into_shape_assume_contig_f(tensor.view(), shape).unwrap()
935}
936
937pub fn to_shape_assume_contig_f<R, T, B, D, D2>(
938    tensor: &TensorAny<R, T, B, D>,
939    shape: D2,
940) -> Result<TensorView<'_, T, B, D2>>
941where
942    D: DimAPI,
943    D2: DimAPI,
944    R: DataAPI<Data = B::Raw>,
945    B: DeviceAPI<T>,
946{
947    into_shape_assume_contig_f(tensor.view(), shape)
948}
949
950pub fn into_shape_assume_contig<R, T, B, D, D2>(tensor: TensorAny<R, T, B, D>, shape: D2) -> TensorAny<R, T, B, D2>
951where
952    R: DataAPI<Data = B::Raw>,
953    B: DeviceAPI<T>,
954    D: DimAPI,
955    D2: DimAPI,
956{
957    into_shape_assume_contig_f(tensor, shape).unwrap()
958}
959
960pub use to_shape_assume_contig as reshape_assume_contig;
961pub use to_shape_assume_contig_f as reshape_assume_contig_f;
962
963impl<R, T, B, D> TensorAny<R, T, B, D>
964where
965    R: DataAPI<Data = B::Raw>,
966    B: DeviceAPI<T>,
967    D: DimAPI,
968{
969    /// Assuming contiguous array, reshapes an array without changing its data.
970    ///
971    /// # See also
972    ///
973    /// [`reshape_assume_contig`]
974    pub fn reshape_assume_contig<D2>(&self, shape: D2) -> TensorView<'_, T, B, D2>
975    where
976        D2: DimAPI,
977    {
978        into_shape_assume_contig(self.view(), shape)
979    }
980
981    pub fn reshape_assume_contig_f<D2>(&self, shape: D2) -> Result<TensorView<'_, T, B, D2>>
982    where
983        D2: DimAPI,
984    {
985        into_shape_assume_contig_f(self.view(), shape)
986    }
987
988    pub fn to_shape_assume_contig<D2>(&self, shape: D2) -> TensorView<'_, T, B, D2>
989    where
990        D2: DimAPI,
991    {
992        into_shape_assume_contig(self.view(), shape)
993    }
994
995    pub fn to_shape_assume_contig_f<D2>(&self, shape: D2) -> Result<TensorView<'_, T, B, D2>>
996    where
997        D2: DimAPI,
998    {
999        into_shape_assume_contig_f(self.view(), shape)
1000    }
1001
1002    pub fn into_shape_assume_contig<D2>(self, shape: D2) -> TensorAny<R, T, B, D2>
1003    where
1004        D2: DimAPI,
1005    {
1006        into_shape_assume_contig(self, shape)
1007    }
1008
1009    pub fn into_shape_assume_contig_f<D2>(self, shape: D2) -> Result<TensorAny<R, T, B, D2>>
1010    where
1011        D2: DimAPI,
1012    {
1013        into_shape_assume_contig_f(self, shape)
1014    }
1015}
1016
1017/* #endregion */
1018
1019/* #region reshape */
1020
1021pub fn change_shape_f<'a, I, R, T, B, D>(tensor: TensorAny<R, T, B, D>, shape: I) -> Result<TensorCow<'a, T, B, IxD>>
1022where
1023    I: TryInto<AxesIndex<isize>, Error = Error>,
1024    R: DataAPI<Data = B::Raw> + DataIntoCowAPI<'a>,
1025    D: DimAPI,
1026    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, IxD, D>,
1027{
1028    // own shape, this is cheap operation
1029    let shape_new = reshape_substitute_negatives(shape.try_into()?.as_ref(), tensor.size())?;
1030    let default_order = tensor.device().default_order();
1031    if let Some(layout_new) = layout_reshapeable(&tensor.layout().to_dim()?, &shape_new, default_order)? {
1032        // shape does not need to be changed
1033        let (storage, _) = tensor.into_raw_parts();
1034        let layout = layout_new.into_dim::<IxD>()?;
1035        return unsafe { Ok(TensorBase::new_unchecked(storage, layout).into_cow()) };
1036    } else {
1037        // clone underlying data by assign_arbitary
1038        let (storage, layout) = tensor.into_raw_parts();
1039        let device = storage.device();
1040        let layout_new = match default_order {
1041            RowMajor => shape_new.new_c_contig(None),
1042            ColMajor => shape_new.new_f_contig(None),
1043        };
1044        let mut storage_new = unsafe { device.empty_impl(layout_new.size())? };
1045        device.assign_arbitary(storage_new.raw_mut(), &layout_new, storage.raw(), &layout)?;
1046        return unsafe { Ok(TensorBase::new_unchecked(storage_new, layout_new).into_cow()) };
1047    }
1048}
1049
1050pub fn change_shape<'a, I, R, T, B, D>(tensor: TensorAny<R, T, B, D>, shape: I) -> TensorCow<'a, T, B, IxD>
1051where
1052    I: TryInto<AxesIndex<isize>, Error = Error>,
1053    R: DataAPI<Data = B::Raw> + DataIntoCowAPI<'a>,
1054    D: DimAPI,
1055    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, IxD, D>,
1056{
1057    change_shape_f(tensor, shape).unwrap()
1058}
1059
1060pub fn into_shape_f<'a, I, R, T, B, D>(tensor: TensorAny<R, T, B, D>, shape: I) -> Result<Tensor<T, B, IxD>>
1061where
1062    I: TryInto<AxesIndex<isize>, Error = Error>,
1063    R: DataAPI<Data = B::Raw> + DataIntoCowAPI<'a>,
1064    D: DimAPI,
1065    T: Clone,
1066    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, IxD, D> + OpAssignAPI<T, IxD>,
1067    B::Raw: Clone + 'a,
1068{
1069    change_shape_f(tensor, shape).map(|v| v.into_owned())
1070}
1071
1072pub fn into_shape<'a, I, R, T, B, D>(tensor: TensorAny<R, T, B, D>, shape: I) -> Tensor<T, B, IxD>
1073where
1074    I: TryInto<AxesIndex<isize>, Error = Error>,
1075    R: DataAPI<Data = B::Raw> + DataIntoCowAPI<'a>,
1076    D: DimAPI,
1077    T: Clone,
1078    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, IxD, D> + OpAssignAPI<T, IxD>,
1079    B::Raw: Clone + 'a,
1080{
1081    into_shape_f(tensor, shape).unwrap()
1082}
1083
1084pub fn to_shape_f<'a, I, R, T, B, D>(tensor: &'a TensorAny<R, T, B, D>, shape: I) -> Result<TensorCow<'a, T, B, IxD>>
1085where
1086    I: TryInto<AxesIndex<isize>, Error = Error>,
1087    R: DataAPI<Data = B::Raw> + DataIntoCowAPI<'a>,
1088    D: DimAPI,
1089    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, IxD, D>,
1090{
1091    change_shape_f(tensor.view(), shape)
1092}
1093
1094pub fn to_shape<'a, I, R, T, B, D>(tensor: &'a TensorAny<R, T, B, D>, shape: I) -> TensorCow<'a, T, B, IxD>
1095where
1096    I: TryInto<AxesIndex<isize>, Error = Error>,
1097    R: DataAPI<Data = B::Raw> + DataIntoCowAPI<'a>,
1098    D: DimAPI,
1099    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, IxD, D>,
1100{
1101    to_shape_f(tensor, shape).unwrap()
1102}
1103
1104pub fn reshape_f<'a, I, R, T, B, D>(tensor: &'a TensorAny<R, T, B, D>, shape: I) -> Result<TensorCow<'a, T, B, IxD>>
1105where
1106    I: TryInto<AxesIndex<isize>, Error = Error>,
1107    R: DataAPI<Data = B::Raw> + DataIntoCowAPI<'a>,
1108    D: DimAPI,
1109    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, IxD, D>,
1110{
1111    to_shape_f(tensor, shape)
1112}
1113
1114pub fn reshape<'a, I, R, T, B, D>(tensor: &'a TensorAny<R, T, B, D>, shape: I) -> TensorCow<'a, T, B, IxD>
1115where
1116    I: TryInto<AxesIndex<isize>, Error = Error>,
1117    R: DataAPI<Data = B::Raw> + DataIntoCowAPI<'a>,
1118    D: DimAPI,
1119    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, IxD, D>,
1120{
1121    to_shape(tensor, shape)
1122}
1123
1124impl<'a, R, T, B, D> TensorAny<R, T, B, D>
1125where
1126    R: DataAPI<Data = B::Raw> + DataIntoCowAPI<'a>,
1127    D: DimAPI,
1128    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, IxD, D> + OpAssignAPI<T, IxD>,
1129    T: Clone,
1130{
1131    pub fn change_shape_f<I>(self, shape: I) -> Result<TensorCow<'a, T, B, IxD>>
1132    where
1133        I: TryInto<AxesIndex<isize>, Error = Error>,
1134    {
1135        change_shape_f(self, shape)
1136    }
1137
1138    pub fn change_shape<I>(self, shape: I) -> TensorCow<'a, T, B, IxD>
1139    where
1140        I: TryInto<AxesIndex<isize>, Error = Error>,
1141    {
1142        change_shape(self, shape)
1143    }
1144
1145    pub fn into_shape_f<I>(self, shape: I) -> Result<Tensor<T, B, IxD>>
1146    where
1147        I: TryInto<AxesIndex<isize>, Error = Error>,
1148        B::Raw: Clone + 'a,
1149    {
1150        into_shape_f(self, shape)
1151    }
1152
1153    pub fn into_shape<I>(self, shape: I) -> Tensor<T, B, IxD>
1154    where
1155        I: TryInto<AxesIndex<isize>, Error = Error>,
1156        B::Raw: Clone + 'a,
1157    {
1158        into_shape(self, shape)
1159    }
1160
1161    pub fn to_shape_f<I>(&'a self, shape: I) -> Result<TensorCow<'a, T, B, IxD>>
1162    where
1163        I: TryInto<AxesIndex<isize>, Error = Error>,
1164    {
1165        self.view().change_shape_f(shape)
1166    }
1167
1168    pub fn to_shape<I>(&'a self, shape: I) -> TensorCow<'a, T, B, IxD>
1169    where
1170        I: TryInto<AxesIndex<isize>, Error = Error>,
1171    {
1172        self.view().change_shape(shape)
1173    }
1174
1175    pub fn reshape_f<I>(&'a self, shape: I) -> Result<TensorCow<'a, T, B, IxD>>
1176    where
1177        I: TryInto<AxesIndex<isize>, Error = Error>,
1178    {
1179        self.view().change_shape_f(shape)
1180    }
1181
1182    pub fn reshape<I>(&'a self, shape: I) -> TensorCow<'a, T, B, IxD>
1183    where
1184        I: TryInto<AxesIndex<isize>, Error = Error>,
1185    {
1186        self.view().change_shape(shape)
1187    }
1188}
1189
1190/* #endregion */
1191
1192/* #region to_layout */
1193
/// Convert a tensor to another layout, consuming it and returning a copy-on-write result.
///
/// The target `layout` must describe the same number of elements as the tensor.
/// When the layouts are identical, or both are contiguous in the device's
/// default order with the same offset, the storage is reused without copying;
/// otherwise the data is cloned element-wise into a new buffer laid out as
/// requested.
///
/// # Errors
///
/// Fails with `InvalidLayout` if the element counts differ, or when device
/// allocation/assignment fails.
pub fn change_layout_f<'a, R, T, B, D, D2>(
    tensor: TensorAny<R, T, B, D>,
    layout: Layout<D2>,
) -> Result<TensorCow<'a, T, B, D2>>
where
    R: DataAPI<Data = B::Raw> + DataIntoCowAPI<'a>,
    D: DimAPI,
    D2: DimAPI,
    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, D2, D>,
{
    let shape = layout.shape();
    // element counts must agree between the tensor and the target layout
    rstsr_assert_eq!(tensor.size(), shape.shape_size(), InvalidLayout)?;
    // compare layouts in a dimension-erased form
    let same_layout = tensor.layout().to_dim::<IxD>()? == layout.to_dim::<IxD>()?;
    // both layouts contiguous (same flavor) at the same offset means identical element order
    let contig_c = tensor.c_contig() && layout.c_contig() && tensor.layout().offset() == layout.offset();
    let contig_f = tensor.f_contig() && layout.f_contig() && tensor.layout().offset() == layout.offset();
    let default_order = tensor.device().default_order();
    let contig = match default_order {
        RowMajor => contig_c,
        ColMajor => contig_f,
    };
    if same_layout || contig {
        // fast path: reinterpret the existing storage, no data cloned
        let (storage, _) = tensor.into_raw_parts();
        // SAFETY: layout addresses the same element count over the same storage
        let tensor = unsafe { TensorBase::new_unchecked(storage, layout) };
        return Ok(tensor.into_cow());
    } else {
        // slow path: element order differs (layout changed, or not contiguous in the
        // default order with the same offset) — clone data via assign_arbitary
        let (storage_old, layout_old) = tensor.into_raw_parts();
        let device = storage_old.device();
        // allocate enough elements to cover the largest index the target layout touches
        let (_, idx_max) = layout.bounds_index()?;
        // SAFETY: buffer covers layout's index bounds and is filled by assign_arbitary below
        let mut storage_new = unsafe { device.empty_impl(idx_max)? };
        device.assign_arbitary(storage_new.raw_mut(), &layout, storage_old.raw(), &layout_old)?;
        let tensor = unsafe { TensorBase::new_unchecked(storage_new, layout) };
        return Ok(tensor.into_cow());
    }
}
1231
1232/// Convert tensor to the other layout.
1233pub fn to_layout<R, T, D, B, D2>(tensor: &TensorAny<R, T, B, D>, layout: Layout<D2>) -> TensorCow<'_, T, B, D2>
1234where
1235    R: DataAPI<Data = B::Raw>,
1236    D: DimAPI,
1237    D2: DimAPI,
1238    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, D2, D>,
1239{
1240    change_layout_f(tensor.view(), layout).unwrap()
1241}
1242
1243pub fn to_layout_f<R, T, D, B, D2>(
1244    tensor: &TensorAny<R, T, B, D>,
1245    layout: Layout<D2>,
1246) -> Result<TensorCow<'_, T, B, D2>>
1247where
1248    R: DataAPI<Data = B::Raw>,
1249    D: DimAPI,
1250    D2: DimAPI,
1251    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, D2, D>,
1252{
1253    change_layout_f(tensor.view(), layout)
1254}
1255
1256pub fn into_layout_f<'a, R, T, B, D, D2>(tensor: TensorAny<R, T, B, D>, layout: Layout<D2>) -> Result<Tensor<T, B, D2>>
1257where
1258    R: DataAPI<Data = B::Raw> + DataIntoCowAPI<'a>,
1259    D: DimAPI,
1260    D2: DimAPI,
1261    T: Clone,
1262    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, D2, D> + OpAssignAPI<T, D2>,
1263    B::Raw: Clone + 'a,
1264{
1265    change_layout_f(tensor, layout).map(|v| v.into_owned())
1266}
1267
1268pub fn into_layout<'a, R, T, B, D, D2>(tensor: TensorAny<R, T, B, D>, layout: Layout<D2>) -> Tensor<T, B, D2>
1269where
1270    R: DataAPI<Data = B::Raw> + DataIntoCowAPI<'a>,
1271    D: DimAPI,
1272    D2: DimAPI,
1273    T: Clone,
1274    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, D2, D> + OpAssignAPI<T, D2>,
1275    B::Raw: Clone + 'a,
1276{
1277    into_layout_f(tensor, layout).unwrap()
1278}
1279
1280pub fn change_layout<'a, R, T, B, D, D2>(tensor: TensorAny<R, T, B, D>, layout: Layout<D2>) -> TensorCow<'a, T, B, D2>
1281where
1282    R: DataAPI<Data = B::Raw> + DataIntoCowAPI<'a>,
1283    D: DimAPI,
1284    D2: DimAPI,
1285    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, D2, D>,
1286{
1287    change_layout_f(tensor, layout).unwrap()
1288}
1289
1290impl<'a, R, T, B, D> TensorAny<R, T, B, D>
1291where
1292    R: DataAPI<Data = B::Raw> + DataIntoCowAPI<'a>,
1293    D: DimAPI,
1294    T: Clone,
1295    B: DeviceAPI<T> + DeviceCreationAnyAPI<T>,
1296{
1297    /// Convert tensor to the other layout.
1298    ///
1299    /// # See also
1300    ///
1301    /// [`to_layout`]
1302    pub fn to_layout<D2>(&self, layout: Layout<D2>) -> TensorCow<'_, T, B, D2>
1303    where
1304        D2: DimAPI,
1305        B: OpAssignArbitaryAPI<T, D2, D>,
1306    {
1307        to_layout(self, layout)
1308    }
1309
1310    pub fn to_layout_f<D2>(&self, layout: Layout<D2>) -> Result<TensorCow<'_, T, B, D2>>
1311    where
1312        D2: DimAPI,
1313        B: OpAssignArbitaryAPI<T, D2, D>,
1314    {
1315        to_layout_f(self, layout)
1316    }
1317
1318    pub fn into_layout_f<D2>(self, layout: Layout<D2>) -> Result<Tensor<T, B, D2>>
1319    where
1320        D2: DimAPI,
1321        B: OpAssignArbitaryAPI<T, D2, D> + OpAssignAPI<T, D2>,
1322        B::Raw: Clone + 'a,
1323    {
1324        into_layout_f(self, layout)
1325    }
1326
1327    pub fn into_layout<D2>(self, layout: Layout<D2>) -> Tensor<T, B, D2>
1328    where
1329        D2: DimAPI,
1330        B: OpAssignArbitaryAPI<T, D2, D> + OpAssignAPI<T, D2>,
1331        B::Raw: Clone + 'a,
1332    {
1333        into_layout(self, layout)
1334    }
1335
1336    pub fn change_layout_f<D2>(self, layout: Layout<D2>) -> Result<TensorCow<'a, T, B, D2>>
1337    where
1338        D2: DimAPI,
1339        B: OpAssignArbitaryAPI<T, D2, D>,
1340    {
1341        change_layout_f(self, layout)
1342    }
1343
1344    pub fn change_layout<D2>(self, layout: Layout<D2>) -> TensorCow<'a, T, B, D2>
1345    where
1346        D2: DimAPI,
1347        B: OpAssignArbitaryAPI<T, D2, D>,
1348    {
1349        change_layout(self, layout)
1350    }
1351}
1352
1353/* #endregion */
1354
1355/* #region to_contig */
1356
1357pub fn change_contig_f<'a, R, T, B, D>(
1358    tensor: TensorAny<R, T, B, D>,
1359    order: FlagOrder,
1360) -> Result<TensorCow<'a, T, B, D>>
1361where
1362    R: DataAPI<Data = B::Raw> + DataIntoCowAPI<'a>,
1363    D: DimAPI,
1364    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, D, D>,
1365{
1366    let shape = tensor.shape();
1367    let layout_new = match order {
1368        RowMajor => shape.new_c_contig(None),
1369        ColMajor => shape.new_f_contig(None),
1370    };
1371    change_layout_f(tensor, layout_new)
1372}
1373
1374pub fn to_contig_f<R, T, B, D>(tensor: &TensorAny<R, T, B, D>, order: FlagOrder) -> Result<TensorCow<'_, T, B, D>>
1375where
1376    R: DataAPI<Data = B::Raw>,
1377    D: DimAPI,
1378    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, D, D>,
1379{
1380    change_contig_f(tensor.view(), order)
1381}
1382
1383pub fn into_contig_f<'a, R, T, B, D>(tensor: TensorAny<R, T, B, D>, order: FlagOrder) -> Result<Tensor<T, B, D>>
1384where
1385    R: DataAPI<Data = B::Raw> + DataIntoCowAPI<'a>,
1386    D: DimAPI,
1387    T: Clone,
1388    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, D, D> + OpAssignAPI<T, D>,
1389    B::Raw: Clone + 'a,
1390{
1391    change_contig_f(tensor, order).map(|v| v.into_owned())
1392}
1393
1394pub fn change_contig<'a, R, T, B, D>(tensor: TensorAny<R, T, B, D>, order: FlagOrder) -> TensorCow<'a, T, B, D>
1395where
1396    R: DataAPI<Data = B::Raw> + DataIntoCowAPI<'a>,
1397    D: DimAPI,
1398    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, D, D>,
1399{
1400    change_contig_f(tensor, order).unwrap()
1401}
1402
1403pub fn to_contig<R, T, B, D>(tensor: &TensorAny<R, T, B, D>, order: FlagOrder) -> TensorCow<'_, T, B, D>
1404where
1405    R: DataAPI<Data = B::Raw>,
1406    D: DimAPI,
1407    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, D, D>,
1408{
1409    to_contig_f(tensor, order).unwrap()
1410}
1411
1412pub fn into_contig<'a, R, T, B, D>(tensor: TensorAny<R, T, B, D>, order: FlagOrder) -> Tensor<T, B, D>
1413where
1414    R: DataAPI<Data = B::Raw> + DataIntoCowAPI<'a>,
1415    D: DimAPI,
1416    T: Clone,
1417    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, D, D> + OpAssignAPI<T, D>,
1418    B::Raw: Clone + 'a,
1419{
1420    into_contig_f(tensor, order).unwrap()
1421}
1422
1423impl<'a, R, T, B, D> TensorAny<R, T, B, D>
1424where
1425    R: DataAPI<Data = B::Raw> + DataIntoCowAPI<'a>,
1426    D: DimAPI,
1427    T: Clone,
1428    B: DeviceAPI<T> + DeviceCreationAnyAPI<T>,
1429{
1430    /// Convert tensor to contiguous, with specified layout.
1431    pub fn to_contig(&self, order: FlagOrder) -> TensorCow<'_, T, B, D>
1432    where
1433        B: OpAssignArbitaryAPI<T, D, D>,
1434    {
1435        to_contig(self, order)
1436    }
1437
1438    pub fn to_contig_f(&self, order: FlagOrder) -> Result<TensorCow<'_, T, B, D>>
1439    where
1440        B: OpAssignArbitaryAPI<T, D, D>,
1441    {
1442        to_contig_f(self, order)
1443    }
1444
1445    pub fn into_contig_f(self, order: FlagOrder) -> Result<Tensor<T, B, D>>
1446    where
1447        B: OpAssignArbitaryAPI<T, D, D> + OpAssignAPI<T, D>,
1448        B::Raw: Clone + 'a,
1449    {
1450        into_contig_f(self, order)
1451    }
1452
1453    pub fn into_contig(self, order: FlagOrder) -> Tensor<T, B, D>
1454    where
1455        B: OpAssignArbitaryAPI<T, D, D> + OpAssignAPI<T, D>,
1456        B::Raw: Clone + 'a,
1457    {
1458        into_contig(self, order)
1459    }
1460
1461    pub fn change_contig_f(self, order: FlagOrder) -> Result<TensorCow<'a, T, B, D>>
1462    where
1463        B: OpAssignArbitaryAPI<T, D, D>,
1464    {
1465        change_contig_f(self, order)
1466    }
1467
1468    pub fn change_contig(self, order: FlagOrder) -> TensorCow<'a, T, B, D>
1469    where
1470        B: OpAssignArbitaryAPI<T, D, D>,
1471    {
1472        change_contig(self, order)
1473    }
1474}
1475
1476/* #endregion */
1477
1478/* #region to_prefer */
1479
1480pub fn change_prefer_f<'a, R, T, B, D>(
1481    tensor: TensorAny<R, T, B, D>,
1482    order: FlagOrder,
1483) -> Result<TensorCow<'a, T, B, D>>
1484where
1485    R: DataAPI<Data = B::Raw> + DataIntoCowAPI<'a>,
1486    D: DimAPI,
1487    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, D, D>,
1488{
1489    if (order == FlagOrder::C && tensor.c_prefer()) || (order == FlagOrder::F && tensor.f_prefer()) {
1490        Ok(tensor.into_cow())
1491    } else {
1492        change_contig_f(tensor, order)
1493    }
1494}
1495
1496pub fn to_prefer_f<R, T, B, D>(tensor: &TensorAny<R, T, B, D>, order: FlagOrder) -> Result<TensorCow<'_, T, B, D>>
1497where
1498    R: DataAPI<Data = B::Raw>,
1499    D: DimAPI,
1500    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, D, D>,
1501{
1502    change_prefer_f(tensor.view(), order)
1503}
1504
1505pub fn into_prefer_f<'a, R, T, B, D>(tensor: TensorAny<R, T, B, D>, order: FlagOrder) -> Result<Tensor<T, B, D>>
1506where
1507    R: DataAPI<Data = B::Raw> + DataIntoCowAPI<'a>,
1508    D: DimAPI,
1509    T: Clone,
1510    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, D, D> + OpAssignAPI<T, D>,
1511    B::Raw: Clone + 'a,
1512{
1513    change_prefer_f(tensor, order).map(|v| v.into_owned())
1514}
1515
1516pub fn change_prefer<'a, R, T, B, D>(tensor: TensorAny<R, T, B, D>, order: FlagOrder) -> TensorCow<'a, T, B, D>
1517where
1518    R: DataAPI<Data = B::Raw> + DataIntoCowAPI<'a>,
1519    D: DimAPI,
1520    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, D, D>,
1521{
1522    change_prefer_f(tensor, order).unwrap()
1523}
1524
1525pub fn to_prefer<R, T, B, D>(tensor: &TensorAny<R, T, B, D>, order: FlagOrder) -> TensorCow<'_, T, B, D>
1526where
1527    R: DataAPI<Data = B::Raw>,
1528    D: DimAPI,
1529    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, D, D>,
1530{
1531    to_prefer_f(tensor, order).unwrap()
1532}
1533
1534pub fn into_prefer<'a, R, T, B, D>(tensor: TensorAny<R, T, B, D>, order: FlagOrder) -> Tensor<T, B, D>
1535where
1536    R: DataAPI<Data = B::Raw> + DataIntoCowAPI<'a>,
1537    D: DimAPI,
1538    T: Clone,
1539    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignArbitaryAPI<T, D, D> + OpAssignAPI<T, D>,
1540    B::Raw: Clone + 'a,
1541{
1542    into_prefer_f(tensor, order).unwrap()
1543}
1544
1545impl<'a, R, T, B, D> TensorAny<R, T, B, D>
1546where
1547    R: DataAPI<Data = B::Raw> + DataIntoCowAPI<'a>,
1548    D: DimAPI,
1549    T: Clone,
1550    B: DeviceAPI<T> + DeviceCreationAnyAPI<T>,
1551{
1552    /// Convert tensor to contiguous, with specified layout.
1553    pub fn to_prefer(&self, order: FlagOrder) -> TensorCow<'_, T, B, D>
1554    where
1555        B: OpAssignArbitaryAPI<T, D, D>,
1556    {
1557        to_prefer(self, order)
1558    }
1559
1560    pub fn to_prefer_f(&self, order: FlagOrder) -> Result<TensorCow<'_, T, B, D>>
1561    where
1562        B: OpAssignArbitaryAPI<T, D, D>,
1563    {
1564        to_prefer_f(self, order)
1565    }
1566
1567    pub fn into_prefer_f(self, order: FlagOrder) -> Result<Tensor<T, B, D>>
1568    where
1569        B: OpAssignArbitaryAPI<T, D, D> + OpAssignAPI<T, D>,
1570        B::Raw: Clone + 'a,
1571    {
1572        into_prefer_f(self, order)
1573    }
1574
1575    pub fn into_prefer(self, order: FlagOrder) -> Tensor<T, B, D>
1576    where
1577        B: OpAssignArbitaryAPI<T, D, D> + OpAssignAPI<T, D>,
1578        B::Raw: Clone + 'a,
1579    {
1580        into_prefer(self, order)
1581    }
1582
1583    pub fn change_prefer_f(self, order: FlagOrder) -> Result<TensorCow<'a, T, B, D>>
1584    where
1585        B: OpAssignArbitaryAPI<T, D, D>,
1586    {
1587        change_prefer_f(self, order)
1588    }
1589
1590    pub fn change_prefer(self, order: FlagOrder) -> TensorCow<'a, T, B, D>
1591    where
1592        B: OpAssignArbitaryAPI<T, D, D>,
1593    {
1594        change_prefer(self, order)
1595    }
1596}
1597
1598/* #endregion */
1599
#[cfg(test)]
mod test_reshape {
    use super::*;

    // Exploratory test: prints reshape results and whether the result shares
    // storage with the source (pointer equality) instead of asserting.
    #[test]
    fn test_playground() {
        #[cfg(not(feature = "col_major"))]
        {
            let a1 = linspace((1.0, 24.0, 24));
            let a2 = a1.to_shape([2, 3, 4]);
            let default_order = a1.device().default_order();
            println!("{a2:?}");
            // pointer equality tells whether to_shape reused the buffer
            println!("{:?}", core::ptr::eq(a1.as_ptr(), a2.as_ptr()));

            let v = layout_reshapeable(a1.layout(), &vec![2, 3, 4], default_order).unwrap();
            println!("{v:?}");

            // f-ordered source flattened back to 1-D
            let b1 = linspace((1.0, 24.0, 24)).into_layout(vec![2, 3, 4].f());
            let b2 = b1.to_shape([24]);
            println!("{b2:?}");
            println!("{:?}", core::ptr::eq(b1.as_ptr(), b2.as_ptr()));

            let v = layout_reshapeable(b1.layout(), &vec![24], default_order).unwrap();
            println!("{v:?}");
        }
        #[cfg(feature = "col_major")]
        {
            let a1 = linspace((1.0, 24.0, 24));
            let a2 = a1.to_shape([2, 3, 4]);
            let default_order = a1.device().default_order();
            println!("{a2:?}");
            println!("{:?}", core::ptr::eq(a1.as_ptr(), a2.as_ptr()));
            // print each slice along the last axis for visual inspection
            println!("a2[:, :, 0] =\n{:}", a2.i((.., .., 0)));
            println!("a2[:, :, 1] =\n{:}", a2.i((.., .., 1)));
            println!("a2[:, :, 2] =\n{:}", a2.i((.., .., 2)));
            println!("a2[:, :, 3] =\n{:}", a2.i((.., .., 3)));

            let v = layout_reshapeable(a1.layout(), &vec![2, 3, 4], default_order).unwrap();
            println!("{v:?}");

            let b1 = linspace((1.0, 24.0, 24)).into_layout(vec![2, 3, 4].f());
            let b2 = b1.to_shape([24]);
            println!("{b2:?}");
            println!("{:?}", core::ptr::eq(b1.as_ptr(), b2.as_ptr()));

            let v = layout_reshapeable(b1.layout(), &vec![24], default_order).unwrap();
            println!("{v:?}");
        }
    }

    // Fully contiguous inputs must always be reshapeable without a copy,
    // producing a contiguous layout of the same order.
    #[test]
    fn test_contig() {
        #[cfg(not(feature = "col_major"))]
        {
            let layout_in = vec![2, 3, 4].c();
            let default_order = RowMajor;
            let layout_out = layout_reshapeable(&layout_in, &vec![2, 3, 4], default_order).unwrap();
            assert_eq!(layout_out.unwrap(), vec![2, 3, 4].c());

            let layout_out = layout_reshapeable(&layout_in, &vec![3, 2, 4], default_order).unwrap();
            assert_eq!(layout_out.unwrap(), vec![3, 2, 4].c());

            // size-1 axes may be inserted freely
            let layout_out = layout_reshapeable(&layout_in, &vec![1, 4, 1, 6], default_order).unwrap();
            assert_eq!(layout_out.unwrap(), vec![1, 4, 1, 6].c());
        }
        #[cfg(feature = "col_major")]
        {
            let layout_in = vec![2, 3, 4].f();
            let default_order = ColMajor;
            let layout_out = layout_reshapeable(&layout_in, &vec![2, 3, 4], default_order).unwrap();
            assert_eq!(layout_out.unwrap(), vec![2, 3, 4].f());

            let layout_out = layout_reshapeable(&layout_in, &vec![3, 2, 4], default_order).unwrap();
            assert_eq!(layout_out.unwrap(), vec![3, 2, 4].f());

            let layout_out = layout_reshapeable(&layout_in, &vec![1, 4, 1, 6], default_order).unwrap();
            assert_eq!(layout_out.unwrap(), vec![1, 4, 1, 6].f());
        }
    }

    // Strided but internally consistent layouts can sometimes be reshaped
    // without copying; reshapes crossing a non-contiguous boundary cannot.
    #[test]
    fn test_partial_contig() {
        #[cfg(not(feature = "col_major"))]
        {
            // np.zeros(12, 15, 18); a[3:, :, ::3]
            // this case is actually contiguous, but with stride 3
            let layout_in = Layout::new(vec![9, 15, 6], vec![270, 18, 3], 810).unwrap();
            let default_order = RowMajor;

            let layout_out = layout_reshapeable(&layout_in, &vec![15, 9, 2, 3], default_order).unwrap();
            assert_eq!(layout_out.as_ref().unwrap().shape(), &vec![15, 9, 2, 3]);
            assert_eq!(layout_out.as_ref().unwrap().stride(), &vec![162, 18, 9, 3]);
            assert_eq!(layout_out.as_ref().unwrap().offset(), layout_in.offset());

            let layout_out = layout_reshapeable(&layout_in, &vec![10, 27, 3], default_order).unwrap();
            assert_eq!(layout_out.as_ref().unwrap().shape(), &vec![10, 27, 3]);
            assert_eq!(layout_out.as_ref().unwrap().stride(), &vec![243, 9, 3]);
            assert_eq!(layout_out.as_ref().unwrap().offset(), layout_in.offset());

            // insert some new axes
            let layout_out = layout_reshapeable(&layout_in, &vec![1, 10, 1, 27, 3], default_order).unwrap();
            assert_eq!(layout_out.as_ref().unwrap().shape(), &vec![1, 10, 1, 27, 3]);
            // strides follows c-contiguous, but zero is also valid for broadcast
            assert_eq!(layout_out.as_ref().unwrap().stride(), &vec![2430, 243, 243, 9, 3]);

            // np.zeros(12, 15, 18); a[3:, :, 3:15:2]
            // this case is not contiguous in last two dimensions
            let layout_in = Layout::new(vec![9, 15, 6], vec![270, 18, 2], 813).unwrap();

            let layout_out = layout_reshapeable(&layout_in, &vec![15, 9, 2, 3], default_order).unwrap();
            assert_eq!(layout_out.as_ref().unwrap().shape(), &vec![15, 9, 2, 3]);
            assert_eq!(layout_out.as_ref().unwrap().stride(), &vec![162, 18, 6, 2]);
            assert_eq!(layout_out.as_ref().unwrap().offset(), layout_in.offset());

            // crossing the discontiguity: must fall back to a copy (None)
            let layout_out = layout_reshapeable(&layout_in, &vec![10, 27, 3], default_order).unwrap();
            assert!(layout_out.is_none());
        }
        #[cfg(feature = "col_major")]
        {
            // mirror of the row-major cases above with reversed axis order
            let layout_in = Layout::new(vec![6, 15, 9], vec![3, 18, 270], 810).unwrap();
            let default_order = ColMajor;

            let layout_out = layout_reshapeable(&layout_in, &vec![3, 2, 9, 15], default_order).unwrap();
            assert_eq!(layout_out.as_ref().unwrap().shape(), &vec![3, 2, 9, 15]);
            assert_eq!(layout_out.as_ref().unwrap().stride(), &vec![3, 9, 18, 162]);
            assert_eq!(layout_out.as_ref().unwrap().offset(), layout_in.offset());

            let layout_out = layout_reshapeable(&layout_in, &vec![3, 27, 10], default_order).unwrap();
            assert_eq!(layout_out.as_ref().unwrap().shape(), &vec![3, 27, 10]);
            assert_eq!(layout_out.as_ref().unwrap().stride(), &vec![3, 9, 243]);
            assert_eq!(layout_out.as_ref().unwrap().offset(), layout_in.offset());

            // insert some new axes
            let layout_out = layout_reshapeable(&layout_in, &vec![3, 27, 1, 10, 1], default_order).unwrap();
            assert_eq!(layout_out.as_ref().unwrap().shape(), &vec![3, 27, 1, 10, 1]);
            // strides follows f-contiguous, but zero is also valid for broadcast
            assert_eq!(layout_out.as_ref().unwrap().stride(), &vec![3, 9, 243, 243, 2430]);

            // np.zeros(12, 15, 18); a[3:, :, 3:15:2]
            // this case is not contiguous in last two dimensions
            let layout_in = Layout::new(vec![6, 15, 9], vec![2, 18, 270], 813).unwrap();

            let layout_out = layout_reshapeable(&layout_in, &vec![3, 2, 9, 15], default_order).unwrap();
            assert_eq!(layout_out.as_ref().unwrap().shape(), &vec![3, 2, 9, 15]);
            assert_eq!(layout_out.as_ref().unwrap().stride(), &vec![2, 6, 18, 162]);
            assert_eq!(layout_out.as_ref().unwrap().offset(), layout_in.offset());

            let layout_out = layout_reshapeable(&layout_in, &vec![10, 27, 3], default_order).unwrap();
            assert!(layout_out.is_none());
        }
    }

    // Negative strides: reshapes that split a reversed axis unevenly fail,
    // while splits aligned with the reversal succeed.
    #[test]
    fn test_minus_stride() {
        #[cfg(not(feature = "col_major"))]
        {
            // np.zeros(12, 15, 18); a[3:, ::-1, ::-3]
            // this case should be seen contiguous in last two dimensions
            let layout_in = Layout::new(vec![9, 15, 6], vec![270, -18, -3], 1079).unwrap();
            let default_order = RowMajor;

            let layout_out = layout_reshapeable(&layout_in, &vec![15, 9, 2, 3], default_order).unwrap();
            assert!(layout_out.is_none());

            let layout_out = layout_reshapeable(&layout_in, &vec![3, 3, 10, 9], default_order).unwrap();
            assert_eq!(layout_out.as_ref().unwrap().shape(), &vec![3, 3, 10, 9]);
            assert_eq!(layout_out.as_ref().unwrap().stride(), &vec![810, 270, -27, -3]);
        }
    }

    // Broadcast (zero-stride) axes: reshapes keeping broadcast dims intact
    // succeed; reordering broadcast against real axes must return None.
    #[test]
    fn test_broadcast_reshape() {
        #[cfg(not(feature = "col_major"))]
        {
            // a = np.zeros(12, 15, 18);
            // b = np.broadcast_to(a[:, None], (12, 16, 15, 18))
            let layout_in = unsafe { Layout::new_unchecked(vec![12, 16, 15, 18], vec![270, 0, 18, 1], 0) };
            let default_order = RowMajor;

            let layout_out = layout_reshapeable(&layout_in, &vec![4, 3, 4, 4, 9, 1, 30], default_order).unwrap();
            assert_eq!(layout_out.as_ref().unwrap().shape(), &vec![4, 3, 4, 4, 9, 1, 30]);
            assert_eq!(layout_out.as_ref().unwrap().stride(), &vec![810, 270, 0, 0, 30, 30, 1]);

            let layout_out = layout_reshapeable(&layout_in, &vec![16, 12, 15, 18], default_order).unwrap();
            assert!(layout_out.is_none());
        }
    }
}
1788
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Tensor;

    // Smoke test: reshape assuming contiguity on a freshly created 1-D tensor.
    #[test]
    fn test_to_shape_assume_contig() {
        let a = linspace((2.5, 3.2, 16));
        let b = a.to_shape_assume_contig_f([4, 4]).unwrap();
        println!("{b:.3?}");
    }

    // expand_dims accepts a single axis, multiple axes, and negative indices.
    #[test]
    fn test_expand_dims() {
        let a: Tensor<f64, _> = zeros([4, 9, 8]);
        let b = a.expand_dims(2);
        assert_eq!(b.shape(), &[4, 9, 1, 8]);
        let b = a.expand_dims([1, 3]);
        assert_eq!(b.shape(), &[4, 1, 9, 8, 1]);
        let b = a.expand_dims([1, -1]);
        assert_eq!(b.shape(), &[4, 1, 9, 8, 1]);
        let b = a.expand_dims([-1, -4, 1, 0]);
        assert_eq!(b.shape(), &[1, 1, 4, 1, 9, 8, 1]);
    }

    // squeeze removes size-1 axes by positive/negative index; out-of-range errors.
    #[test]
    fn test_squeeze() {
        let a: Tensor<f64, _> = zeros([4, 1, 9, 1, 8, 1]);
        let b = a.squeeze(3);
        assert_eq!(b.shape(), &[4, 1, 9, 8, 1]);
        let b = a.squeeze([1, 3]);
        assert_eq!(b.shape(), &[4, 9, 8, 1]);
        let b = a.squeeze([1, -1]);
        assert_eq!(b.shape(), &[4, 9, 1, 8]);
        // axis -7 is out of range for a 6-D tensor
        let b = a.squeeze_f(-7);
        assert!(b.is_err());
    }

    // flip reverses the given axes without changing the shape.
    #[test]
    fn test_flip() {
        let a = arange(24.0).into_shape([2, 3, 4]).into_owned();
        println!("{a:?}");

        let b = a.flip(1);
        println!("{b:?}");
        assert_eq!(b.shape(), &[2, 3, 4]);
        let c = a.flip([0, -1]);
        println!("{c:?}");
        assert_eq!(c.shape(), &[2, 3, 4]);
    }

    // swapaxes exchanges two axes of the shape.
    #[test]
    fn test_swapaxes() {
        let a = arange(24.0).into_shape([2, 3, 4]).into_owned();
        println!("{a:?}");

        let b = a.swapaxes(0, 1);
        println!("{b:?}");
        assert_eq!(b.shape(), &[3, 2, 4]);
    }

    // to_shape on a non-contiguous view: -1 is substituted; an impossible
    // shape (size not divisible) errors.
    #[test]
    fn test_to_shape() {
        let a = linspace((0.0, 15.0, 16));
        let mut a = a.to_shape([4, 4]);
        // directly overwrite the layout to fabricate a strided (non-contiguous) view
        a.layout = Layout::new(vec![2, 2], vec![2, 4], 0).unwrap();
        println!("{a:?}");
        let b = a.to_shape([2, 2]);
        println!("{b:?}");

        let c = a.to_shape([2, -1]);
        println!("{c:?}");
        assert_eq!(c.shape(), &[2, 2]);

        // 4 elements cannot be split into 3 rows
        let d = a.to_shape_f([3, -1]);
        assert!(d.is_err());
    }

    // broadcast inserts zero strides on the broadcast axes.
    #[test]
    fn test_broadcast_to() {
        #[cfg(not(feature = "col_major"))]
        {
            let a = linspace((0.0, 15.0, 16));
            let a = a.into_shape_assume_contig_f([4, 1, 4]).unwrap();
            let a = a.to_broadcast_f([6, 4, 3, 4]).unwrap();
            println!("{a:?}");
            assert_eq!(a.layout(), unsafe { &Layout::new_unchecked([6, 4, 3, 4], [0, 4, 0, 1], 0) });
        }
        #[cfg(feature = "col_major")]
        {
            let a = linspace((0.0, 15.0, 16));
            let a = a.into_shape_assume_contig_f([4, 1, 4]).unwrap();
            let a = a.to_broadcast_f([4, 3, 4, 6]).unwrap();
            println!("{a:?}");
            assert_eq!(a.layout(), unsafe { &Layout::new_unchecked([4, 3, 4, 6], [1, 0, 4, 0], 0) });
        }
    }

    // into_layout with an arbitrary (strided, offset) target layout.
    #[test]
    fn test_to_layout() {
        let a = linspace((0.0, 15.0, 16));
        let a = a.change_shape([4, 4]);
        let a = a.into_layout(Layout::new([2, 8], [12, 120], 8).unwrap());
        println!("{a:?}");
    }
}