rstsr_core/tensor/
creation_from_tensor.rs

1//! Creation methods for `Tensor` struct from other tensors.
2//!
3//! Todo list:
4//! - [ ] `diag`
5//! - [ ] `tril`
6//! - [ ] `triu`
7
8use core::mem::transmute;
9
10use crate::prelude_dev::*;
11
12/* #region diag */
13
/// Backing trait for the [`diag`] / [`diag_f`] free functions.
///
/// `Inp` is a marker type parameter used only to keep overlapping impls apart.
pub trait DiagAPI<Inp> {
    /// Tensor type produced by the diagonal operation.
    type Out;

    /// Fallible diagonal extraction/construction.
    fn diag_f(self) -> Result<Self::Out>;
    /// Unwrapping convenience wrapper around [`DiagAPI::diag_f`].
    fn diag(self) -> Self::Out
    where
        Self: Sized,
    {
        Self::diag_f(self).rstsr_unwrap()
    }
}
25
26/// Extract a diagonal or construct a diagonal tensor.
27///
28/// - If input is a 2-D tensor, return a copy of its diagonal (with offset).
29/// - If input is a 1-D tensor, construct a 2-D tensor with the input as its diagonal.
30///
31/// # See also
32///
33/// - [numpy.diag](https://numpy.org/doc/stable/reference/generated/numpy.diag.html)
34pub fn diag<Args, Inp>(param: Args) -> Args::Out
35where
36    Args: DiagAPI<Inp>,
37{
38    Args::diag(param)
39}
40
41pub fn diag_f<Args, Inp>(param: Args) -> Result<Args::Out>
42where
43    Args: DiagAPI<Inp>,
44{
45    Args::diag_f(param)
46}
47
/// `diag` with an explicit diagonal offset: `(tensor, offset)`.
///
/// Positive `offset` selects a diagonal above the main one, negative below.
impl<R, T, B, D> DiagAPI<()> for (&TensorAny<R, T, B, D>, isize)
where
    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
    T: Clone + Default,
    D: DimAPI,
    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, Ix1>,
{
    type Out = Tensor<T, B, IxD>;

    fn diag_f(self) -> Result<Self::Out> {
        let (tensor, offset) = self;
        if tensor.ndim() == 1 {
            // 1-D input: build an (n + |offset|) × (n + |offset|) matrix filled
            // with `T::default()` and write the input along the requested diagonal.
            let layout_diag = tensor.layout().to_dim::<Ix1>()?;
            let n_row = tensor.size() + offset.unsigned_abs();
            let mut result = full_f(([n_row, n_row], T::default(), tensor.device()))?;
            let layout_result = result.layout().diagonal(Some(offset), Some(0), Some(1))?;
            let device = tensor.device();
            device.assign(result.raw_mut(), &layout_result.to_dim()?, tensor.raw(), &layout_diag)?;
            return Ok(result);
        } else if tensor.ndim() == 2 {
            // 2-D input: copy the selected diagonal into a fresh 1-D tensor.
            let layout = tensor.layout().to_dim::<Ix2>()?;
            let layout_diag = layout.diagonal(Some(offset), Some(0), Some(1))?;
            let size = layout_diag.size();
            let device = tensor.device();
            // uninitialized alloc is fine: every element is overwritten by `assign` below
            let mut result = unsafe { empty_f(([size], device))? };
            let layout_result = result.layout().to_dim()?;
            device.assign(result.raw_mut(), &layout_result, tensor.raw(), &layout_diag)?;
            return Ok(result);
        } else {
            return rstsr_raise!(InvalidLayout, "diag only support 1-D or 2-D tensor.");
        }
    }
}
81
82impl<R, T, B, D> DiagAPI<()> for &TensorAny<R, T, B, D>
83where
84    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
85    T: Clone + Default,
86    D: DimAPI,
87    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, Ix1>,
88{
89    type Out = Tensor<T, B, IxD>;
90
91    fn diag_f(self) -> Result<Self::Out> {
92        return diag_f((self, 0));
93    }
94}
95
96/* #endregion */
97
98/* #region meshgrid */
99
/// Backing trait for the [`meshgrid`] / [`meshgrid_f`] free functions.
///
/// `Inp` is a marker type parameter used only to keep overlapping impls apart.
pub trait MeshgridAPI<Inp> {
    /// Collection of coordinate tensors produced by the operation.
    type Out;

    /// Fallible meshgrid construction.
    fn meshgrid_f(self) -> Result<Self::Out>;
    /// Unwrapping convenience wrapper around [`MeshgridAPI::meshgrid_f`].
    fn meshgrid(self) -> Self::Out
    where
        Self: Sized,
    {
        Self::meshgrid_f(self).rstsr_unwrap()
    }
}
111
112/// Returns coordinate matrices from coordinate vectors.
113///
114/// # See also
115///
116/// - [Python Array Standard `meshgrid`](https://data-apis.org/array-api/latest/API_specification/generated/array_api.meshgrid.html)
117pub fn meshgrid<Args, Inp>(args: Args) -> Args::Out
118where
119    Args: MeshgridAPI<Inp>,
120{
121    Args::meshgrid(args)
122}
123
124pub fn meshgrid_f<Args, Inp>(args: Args) -> Result<Args::Out>
125where
126    Args: MeshgridAPI<Inp>,
127{
128    Args::meshgrid_f(args)
129}
130
/// Core `meshgrid` implementation: `(tensors, indexing, copy)`.
///
/// `indexing` must be `"ij"` (matrix) or `"xy"` (Cartesian). With
/// `copy = false` the broadcast results are returned as-is; with `copy = true`
/// each result is made contiguous in the device's default order.
impl<R, T, B, D> MeshgridAPI<()> for (Vec<&TensorAny<R, T, B, D>>, &str, bool)
where
    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw> + DataCloneAPI,
    T: Clone,
    D: DimAPI,
    B: DeviceAPI<T>
        + DeviceRawAPI<MaybeUninit<T>>
        + DeviceCreationAnyAPI<T>
        + OpAssignAPI<T, IxD>
        + OpAssignArbitaryAPI<T, IxD, IxD>,
    <B as DeviceRawAPI<T>>::Raw: Clone,
{
    type Out = Vec<Tensor<T, B, IxD>>;

    fn meshgrid_f(self) -> Result<Self::Out> {
        let (tensors, indexing, copy) = self;

        // validate the indexing mode up front
        match indexing {
            "ij" | "xy" => (),
            _ => rstsr_raise!(InvalidValue, "indexing must be 'ij' or 'xy'.")?,
        }

        // fast return for tensors with length 0/1
        if tensors.is_empty() {
            return Ok(vec![]);
        } else if tensors.len() == 1 {
            // NOTE(review): this path always materializes an owned copy, even
            // when `copy == false` — confirm that is intended.
            let tensor = tensors[0];
            rstsr_assert_eq!(tensor.ndim(), 1, InvalidLayout, "meshgrid only support 1-D tensor.")?;
            return Ok(vec![tensor.view().into_dim().into_owned()]);
        }

        // check
        // a. all tensors must have the same device
        // b. all tensors are 1-D
        let device = tensors[0].device();
        tensors.iter().try_for_each(|tensor| -> Result<()> {
            rstsr_assert_eq!(tensor.ndim(), 1, InvalidLayout, "meshgrid only support 1-D tensor.")?;
            rstsr_assert!(
                tensor.device().same_device(device),
                DeviceMismatch,
                "All tensors must be on the same device."
            )?;
            Ok(())
        })?;

        let ndim = tensors.len();
        let s0 = vec![1isize; ndim];

        // reshape each 1-D input to rank `ndim` with a single inferred (-1) axis,
        // so that broadcasting spreads it along its own dimension
        let tensors = tensors
            .iter()
            .enumerate()
            .map(|(i, tensor)| {
                let mut shape_new = s0.clone();
                if indexing == "xy" && i == 0 {
                    // special case for indexing="xy": first two axes swap roles
                    shape_new[1] = -1;
                } else if indexing == "xy" && i == 1 {
                    // special case for indexing="xy"
                    shape_new[0] = -1;
                } else {
                    // s0[:i] + (-1,) + s0[i+1:]
                    shape_new[i] = -1;
                }
                tensor.view().into_dim::<IxD>().into_shape_f(shape_new)
            })
            .collect::<Result<Vec<_>>>()?;
        // broadcast all reshaped tensors to one common shape
        let tensors = broadcast_arrays_f(tensors)?;

        if !copy {
            Ok(tensors)
        } else {
            // force each result into the device's default contiguous order
            tensors.into_iter().map(|t| t.into_contig_f(device.default_order())).collect()
        }
    }
}
208
// implementation for reference tensors
//
// Each table row below expands to one impl that normalizes its arguments
// (defaults: indexing = "xy", copy = true — matching numpy.meshgrid) and
// forwards to the core (Vec<&TensorAny>, &str, bool) impl above.
#[duplicate_item(
    ImplType         ImplStruct                                           tuple_args                  tuple_internal                             ;
   [              ] [(&Vec<&TensorAny<R, T, B, D>>, &str, bool)] [(tensors, indexing, copy)] [(tensors.to_vec(), indexing, copy)];
   [const N: usize] [([&TensorAny<R, T, B, D>; N] , &str, bool)] [(tensors, indexing, copy)] [(tensors.to_vec(), indexing, copy)];
   [              ] [(Vec<&TensorAny<R, T, B, D>> , &str,     )] [(tensors, indexing,     )] [(tensors.to_vec(), indexing, true)];
   [              ] [(&Vec<&TensorAny<R, T, B, D>>, &str,     )] [(tensors, indexing,     )] [(tensors.to_vec(), indexing, true)];
   [const N: usize] [([&TensorAny<R, T, B, D>; N] , &str,     )] [(tensors, indexing,     )] [(tensors.to_vec(), indexing, true)];
   [              ] [(Vec<&TensorAny<R, T, B, D>> ,       bool)] [(tensors,           copy)] [(tensors.to_vec(), "xy"    , copy)];
   [              ] [(&Vec<&TensorAny<R, T, B, D>>,       bool)] [(tensors,           copy)] [(tensors.to_vec(), "xy"    , copy)];
   [const N: usize] [([&TensorAny<R, T, B, D>; N] ,       bool)] [(tensors,           copy)] [(tensors.to_vec(), "xy"    , copy)];
   [              ] [ Vec<&TensorAny<R, T, B, D>>              ] [ tensors                 ] [(tensors.to_vec(), "xy"    , true)];
   [              ] [ &Vec<&TensorAny<R, T, B, D>>             ] [ tensors                 ] [(tensors.to_vec(), "xy"    , true)];
   [const N: usize] [ [&TensorAny<R, T, B, D>; N]              ] [ tensors                 ] [(tensors.to_vec(), "xy"    , true)];
)]
impl<R, T, B, D, ImplType> MeshgridAPI<()> for ImplStruct
where
    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw> + DataCloneAPI,
    T: Clone,
    D: DimAPI,
    B: DeviceAPI<T>
        + DeviceRawAPI<MaybeUninit<T>>
        + DeviceCreationAnyAPI<T>
        + OpAssignAPI<T, IxD>
        + OpAssignArbitaryAPI<T, IxD, IxD>,
    <B as DeviceRawAPI<T>>::Raw: Clone,
{
    type Out = Vec<Tensor<T, B, IxD>>;

    fn meshgrid_f(self) -> Result<Self::Out> {
        // `tuple_args` / `tuple_internal` are substituted per table row above.
        let tuple_args = self;
        let (tensors, indexing, copy) = tuple_internal;
        MeshgridAPI::meshgrid_f((tensors, indexing, copy))
    }
}
244
// implementation for non-reference tensors
//
// Same normalization as the reference-tensor table (defaults: indexing = "xy",
// copy = true); the owned container is converted to a Vec of references before
// delegating to the core impl.
#[duplicate_item(
    ImplType         ImplStruct                                           tuple_args           tuple_internal ;
   [              ] [(Vec<TensorAny<R, T, B, D>> , &str, bool)] [(tensors, indexing, copy)] [(indexing, copy)];
   [              ] [(&Vec<TensorAny<R, T, B, D>>, &str, bool)] [(tensors, indexing, copy)] [(indexing, copy)];
   [const N: usize] [([TensorAny<R, T, B, D>; N] , &str, bool)] [(tensors, indexing, copy)] [(indexing, copy)];
   [              ] [(Vec<TensorAny<R, T, B, D>> , &str,     )] [(tensors, indexing,     )] [(indexing, true)];
   [              ] [(&Vec<TensorAny<R, T, B, D>>, &str,     )] [(tensors, indexing,     )] [(indexing, true)];
   [const N: usize] [([TensorAny<R, T, B, D>; N] , &str,     )] [(tensors, indexing,     )] [(indexing, true)];
   [              ] [(Vec<TensorAny<R, T, B, D>> ,       bool)] [(tensors,           copy)] [("xy"    , copy)];
   [              ] [(&Vec<TensorAny<R, T, B, D>>,       bool)] [(tensors,           copy)] [("xy"    , copy)];
   [const N: usize] [([TensorAny<R, T, B, D>; N] ,       bool)] [(tensors,           copy)] [("xy"    , copy)];
   [              ] [ Vec<TensorAny<R, T, B, D>>              ] [ tensors                 ] [("xy"    , true)];
   [              ] [ &Vec<TensorAny<R, T, B, D>>             ] [ tensors                 ] [("xy"    , true)];
   [const N: usize] [ [TensorAny<R, T, B, D>; N]              ] [ tensors                 ] [("xy"    , true)];
)]
impl<R, T, B, D, ImplType> MeshgridAPI<()> for ImplStruct
where
    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw> + DataCloneAPI,
    T: Clone,
    D: DimAPI,
    B: DeviceAPI<T>
        + DeviceRawAPI<MaybeUninit<T>>
        + DeviceCreationAnyAPI<T>
        + OpAssignAPI<T, IxD>
        + OpAssignArbitaryAPI<T, IxD, IxD>,
    <B as DeviceRawAPI<T>>::Raw: Clone,
{
    type Out = Vec<Tensor<T, B, IxD>>;

    fn meshgrid_f(self) -> Result<Self::Out> {
        // `tuple_args` binds the tensor container (and any explicit options);
        // `tuple_internal` supplies defaults for the omitted ones.
        let tuple_args = self;
        let (indexing, copy) = tuple_internal;
        let tensors = tensors.iter().collect::<Vec<_>>();
        MeshgridAPI::meshgrid_f((tensors, indexing, copy))
    }
}
282
283/* #endregion */
284
285/* #region concat */
286
/// Backing trait for the [`concat`] / [`concat_f`] free functions.
///
/// `Inp` is a marker type parameter used only to keep overlapping impls apart.
pub trait ConcatAPI<Inp> {
    /// Tensor type produced by the concatenation.
    type Out;

    /// Fallible concatenation along an existing axis.
    fn concat_f(self) -> Result<Self::Out>;
    /// Unwrapping convenience wrapper around [`ConcatAPI::concat_f`].
    fn concat(self) -> Self::Out
    where
        Self: Sized,
    {
        Self::concat_f(self).rstsr_unwrap()
    }
}
298
299/// Join a sequence of arrays along an existing axis.
300///
301/// # See also
302///
303/// - [Python Array Standard `concatnate`](https://data-apis.org/array-api/latest/API_specification/generated/array_api.concat.html)
304pub fn concat<Args, Inp>(args: Args) -> Args::Out
305where
306    Args: ConcatAPI<Inp>,
307{
308    Args::concat(args)
309}
310
311pub fn concat_f<Args, Inp>(args: Args) -> Result<Args::Out>
312where
313    Args: ConcatAPI<Inp>,
314{
315    Args::concat_f(args)
316}
317
318pub use concat as concatenate;
319pub use concat_f as concatenate_f;
320
/// Core `concat` implementation: joins owned tensors along an existing `axis`.
///
/// Negative `axis` counts from the last dimension. All inputs must share the
/// same device, the same ndim (> 0), and the same shape on every axis other
/// than `axis`.
impl<R, T, B, D> ConcatAPI<()> for (Vec<TensorAny<R, T, B, D>>, isize)
where
    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
    T: Clone + Default,
    D: DimAPI,
    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, IxD>,
{
    type Out = Tensor<T, B, IxD>;

    fn concat_f(self) -> Result<Self::Out> {
        let (tensors, axis) = self;

        // quick error for empty tensors
        rstsr_assert!(!tensors.is_empty(), InvalidValue, "concat requires at least one tensor.")?;

        // check same device and same ndim
        let device = tensors[0].device().clone();
        let ndim = tensors[0].ndim();

        rstsr_assert!(ndim > 0, InvalidLayout, "All tensors must have ndim > 0 in concat.")?;
        tensors.iter().try_for_each(|tensor| -> Result<()> {
            rstsr_assert_eq!(tensor.ndim(), ndim, InvalidLayout, "All tensors must have the same ndim.")?;
            rstsr_assert!(
                tensor.device().same_device(&device),
                DeviceMismatch,
                "All tensors must be on the same device."
            )?;
            Ok(())
        })?;

        // check and make axis positive
        let axis = if axis < 0 { ndim as isize + axis } else { axis };
        rstsr_pattern!(axis, 0..ndim as isize, InvalidLayout, "axis out of bounds")?;
        let axis = axis as usize;

        // - check shape compatibility (dimension other than axis must match)
        // - calculate the new shape (sum of input sizes along `axis`)
        let mut new_axis_size = 0;
        let mut shape_other = tensors[0].shape().as_ref().to_vec();
        shape_other.remove(axis);
        for tensor in &tensors {
            let mut shape_other_i = tensor.shape().as_ref().to_vec();
            new_axis_size += shape_other_i.remove(axis);
            rstsr_assert_eq!(
                shape_other_i,
                shape_other,
                InvalidLayout,
                "All tensors must have the same shape except for the concatenation axis."
            )?;
        }
        shape_other.insert(axis, new_axis_size);
        let new_shape = shape_other;

        // create the result tensor
        // uninitialized alloc is fine: the loop below overwrites every slice along `axis`
        let mut result = unsafe { empty_f((new_shape, &device))? };

        // assign each tensor into its slot of the result along `axis`
        let mut offset = 0;
        for tensor in tensors {
            let layout = tensor.layout().to_dim::<IxD>()?;
            let axis_size = tensor.shape()[axis];
            let layout_result = result.layout().dim_narrow(axis as isize, slice!(offset, offset + axis_size))?;
            device.assign(result.raw_mut(), &layout_result, tensor.raw(), &layout)?;
            offset += axis_size;
        }

        Ok(result)
    }
}
390
// Convenience impls: accept owned/borrowed containers of (possibly borrowed)
// tensors and any common integer type for `axis`; each expansion takes views
// and forwards to the core (Vec<TensorAny>, isize) impl.
#[duplicate_item(
    ImplType         ImplStruct                            ;
   [              ] [(&Vec<TensorAny<R, T, B, D>> , isize)];
   [const N: usize] [([TensorAny<R, T, B, D>; N]  , isize)];
   [              ] [(Vec<TensorAny<R, T, B, D>>  , usize)];
   [              ] [(&Vec<TensorAny<R, T, B, D>> , usize)];
   [const N: usize] [([TensorAny<R, T, B, D>; N]  , usize)];
   [              ] [(Vec<TensorAny<R, T, B, D>>  , i32  )];
   [              ] [(&Vec<TensorAny<R, T, B, D>> , i32  )];
   [const N: usize] [([TensorAny<R, T, B, D>; N]  , i32  )];
   [              ] [(Vec<&TensorAny<R, T, B, D>> , isize)];
   [              ] [(&Vec<&TensorAny<R, T, B, D>>, isize)];
   [const N: usize] [([&TensorAny<R, T, B, D>; N] , isize)];
   [              ] [(Vec<&TensorAny<R, T, B, D>> , usize)];
   [              ] [(&Vec<&TensorAny<R, T, B, D>>, usize)];
   [const N: usize] [([&TensorAny<R, T, B, D>; N] , usize)];
   [              ] [(Vec<&TensorAny<R, T, B, D>> , i32  )];
   [              ] [(&Vec<&TensorAny<R, T, B, D>>, i32  )];
   [const N: usize] [([&TensorAny<R, T, B, D>; N] , i32  )];
)]
impl<R, T, B, D, ImplType> ConcatAPI<()> for ImplStruct
where
    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
    T: Clone + Default,
    D: DimAPI,
    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, IxD>,
{
    type Out = Tensor<T, B, IxD>;

    fn concat_f(self) -> Result<Self::Out> {
        let (tensors, axis) = self;
        #[allow(clippy::unnecessary_cast)]
        let axis = axis as isize;
        // views keep the originals alive; the core impl is generic over R
        let tensors = tensors.iter().map(|t| t.view()).collect::<Vec<_>>();
        ConcatAPI::concat_f((tensors, axis))
    }
}
428
// Convenience impls without an explicit axis: concatenation defaults to axis 0.
#[duplicate_item(
    ImplType         ImplStruct                   ;
   [              ] [Vec<TensorAny<R, T, B, D>>  ];
   [              ] [&Vec<TensorAny<R, T, B, D>> ];
   [const N: usize] [[TensorAny<R, T, B, D>; N]  ];
   [              ] [Vec<&TensorAny<R, T, B, D>> ];
   [              ] [&Vec<&TensorAny<R, T, B, D>>];
   [const N: usize] [[&TensorAny<R, T, B, D>; N] ];
)]
impl<R, T, B, D, ImplType> ConcatAPI<()> for ImplStruct
where
    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
    T: Clone + Default,
    D: DimAPI,
    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, IxD>,
{
    type Out = Tensor<T, B, IxD>;

    fn concat_f(self) -> Result<Self::Out> {
        let tensors = self;
        #[allow(clippy::unnecessary_cast)]
        let axis = 0;
        let tensors = tensors.iter().map(|t| t.view()).collect::<Vec<_>>();
        ConcatAPI::concat_f((tensors, axis))
    }
}
455
456/* #endregion */
457
458/* #region hstack */
459
/// Backing trait for the [`hstack`] / [`hstack_f`] free functions.
///
/// `Inp` is a marker type parameter used only to keep overlapping impls apart.
pub trait HStackAPI<Inp> {
    /// Tensor type produced by the horizontal stacking.
    type Out;

    /// Fallible horizontal stacking.
    fn hstack_f(self) -> Result<Self::Out>;
    /// Unwrapping convenience wrapper around [`HStackAPI::hstack_f`].
    fn hstack(self) -> Self::Out
    where
        Self: Sized,
    {
        Self::hstack_f(self).rstsr_unwrap()
    }
}
471
472/// Stack tensors in sequence horizontally (column-wise).
473///
474/// # See also
475///
476/// [NumPy `hstack`](https://numpy.org/doc/stable/reference/generated/numpy.hstack.html)
477pub fn hstack<Args, Inp>(args: Args) -> Args::Out
478where
479    Args: HStackAPI<Inp>,
480{
481    Args::hstack(args)
482}
483
484pub fn hstack_f<Args, Inp>(args: Args) -> Result<Args::Out>
485where
486    Args: HStackAPI<Inp>,
487{
488    Args::hstack_f(args)
489}
490
// NumPy-style hstack: 1-D inputs are joined along axis 0, higher-rank inputs
// along axis 1 (columns).
#[duplicate_item(
    ImplType         ImplStruct                   ;
   [              ] [Vec<TensorAny<R, T, B, D>>  ];
   [              ] [&Vec<TensorAny<R, T, B, D>> ];
   [const N: usize] [[TensorAny<R, T, B, D>; N]  ];
   [              ] [Vec<&TensorAny<R, T, B, D>> ];
   [              ] [&Vec<&TensorAny<R, T, B, D>>];
   [const N: usize] [[&TensorAny<R, T, B, D>; N] ];
)]
impl<R, T, B, D, ImplType> HStackAPI<()> for ImplStruct
where
    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
    T: Clone + Default,
    D: DimAPI,
    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, IxD>,
{
    type Out = Tensor<T, B, IxD>;

    fn hstack_f(self) -> Result<Self::Out> {
        let tensors = self;

        if tensors.is_empty() {
            return rstsr_raise!(InvalidValue, "hstack requires at least one tensor.");
        }

        // axis choice follows numpy.hstack: 0 for 1-D inputs, 1 otherwise
        if tensors[0].ndim() == 1 {
            ConcatAPI::concat_f((tensors, 0))
        } else {
            ConcatAPI::concat_f((tensors, 1))
        }
    }
}
523
524/* #endregion */
525
526/* #region vstack */
527
/// Backing trait for the [`vstack`] / [`vstack_f`] free functions.
///
/// `Inp` is a marker type parameter used only to keep overlapping impls apart.
pub trait VStackAPI<Inp> {
    /// Tensor type produced by the vertical stacking.
    type Out;

    /// Fallible vertical stacking.
    fn vstack_f(self) -> Result<Self::Out>;
    /// Unwrapping convenience wrapper around [`VStackAPI::vstack_f`].
    fn vstack(self) -> Self::Out
    where
        Self: Sized,
    {
        Self::vstack_f(self).rstsr_unwrap()
    }
}
539
/// Stack tensors in sequence vertically (row-wise).
///
/// # See also
///
/// [NumPy `vstack`](https://numpy.org/doc/stable/reference/generated/numpy.vstack.html)
pub fn vstack<Args, Inp>(args: Args) -> Args::Out
where
    Args: VStackAPI<Inp>,
{
    Args::vstack(args)
}
551
552pub fn vstack_f<Args, Inp>(args: Args) -> Result<Args::Out>
553where
554    Args: VStackAPI<Inp>,
555{
556    Args::vstack_f(args)
557}
558
559#[duplicate_item(
560    ImplType         ImplStruct                   ;
561   [              ] [Vec<TensorAny<R, T, B, D>>  ];
562   [              ] [&Vec<TensorAny<R, T, B, D>> ];
563   [const N: usize] [[TensorAny<R, T, B, D>; N]  ];
564   [              ] [Vec<&TensorAny<R, T, B, D>> ];
565   [              ] [&Vec<&TensorAny<R, T, B, D>>];
566   [const N: usize] [[&TensorAny<R, T, B, D>; N] ];
567)]
568impl<R, T, B, D, ImplType> VStackAPI<()> for ImplStruct
569where
570    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
571    T: Clone + Default,
572    D: DimAPI,
573    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, IxD>,
574{
575    type Out = Tensor<T, B, IxD>;
576
577    fn vstack_f(self) -> Result<Self::Out> {
578        let tensors = self;
579
580        if tensors.is_empty() {
581            return rstsr_raise!(InvalidValue, "vstack requires at least one tensor.");
582        }
583
584        ConcatAPI::concat_f((tensors, 0))
585    }
586}
587
588/* #endregion */
589
590/* #region stack */
591
/// Backing trait for the [`stack`] / [`stack_f`] free functions.
///
/// `Inp` is a marker type parameter used only to keep overlapping impls apart.
pub trait StackAPI<Inp> {
    /// Tensor type produced by stacking along a new axis.
    type Out;

    /// Fallible stacking along a new axis.
    fn stack_f(self) -> Result<Self::Out>;
    /// Unwrapping convenience wrapper around [`StackAPI::stack_f`].
    fn stack(self) -> Self::Out
    where
        Self: Sized,
    {
        Self::stack_f(self).rstsr_unwrap()
    }
}
603
604/// Joins a sequence of arrays along a new axis.
605///
606/// # See also
607///
608/// [Python Array Standard `stack`](https://data-apis.org/array-api/latest/API_specification/generated/array_api.stack.html)
609pub fn stack<Args, Inp>(args: Args) -> Args::Out
610where
611    Args: StackAPI<Inp>,
612{
613    Args::stack(args)
614}
615
616pub fn stack_f<Args, Inp>(args: Args) -> Result<Args::Out>
617where
618    Args: StackAPI<Inp>,
619{
620    Args::stack_f(args)
621}
622
/// Core `stack` implementation: joins owned tensors along a **new** axis.
///
/// All inputs must share device and shape. The valid axis range is
/// `-(ndim + 1) ..= ndim` because the output has one extra dimension.
impl<R, T, B, D> StackAPI<()> for (Vec<TensorAny<R, T, B, D>>, isize)
where
    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
    T: Clone + Default,
    D: DimAPI,
    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, IxD>,
{
    type Out = Tensor<T, B, IxD>;

    fn stack_f(self) -> Result<Self::Out> {
        let (tensors, axis) = self;

        // quick error for empty tensors
        rstsr_assert!(!tensors.is_empty(), InvalidValue, "stack requires at least one tensor.")?;

        // check same device and same shape
        let device = tensors[0].device().clone();
        let ndim = tensors[0].ndim();
        let shape_orig = tensors[0].shape();

        rstsr_assert!(ndim > 0, InvalidLayout, "All tensors must have ndim > 0 in stack.")?;
        tensors.iter().try_for_each(|tensor| -> Result<()> {
            rstsr_assert_eq!(tensor.shape(), shape_orig, InvalidLayout, "All tensors must have the same shape.")?;
            rstsr_assert!(
                tensor.device().same_device(&device),
                DeviceMismatch,
                "All tensors must be on the same device."
            )?;
            Ok(())
        })?;

        // check and make axis positive; the extra `+ 1` accounts for the new axis
        let axis = if axis < 0 { ndim as isize + axis + 1 } else { axis };
        rstsr_pattern!(axis, 0..=ndim as isize, InvalidLayout, "axis out of bounds")?;
        let axis = axis as usize;

        // insert a unit dimension at `axis` in every tensor
        let tensors = tensors.into_iter().map(|tensor| tensor.into_expand_dims_f(axis)).collect::<Result<Vec<_>>>()?;

        // use concat function to perform the stacking
        ConcatAPI::concat_f((tensors, axis as isize))
    }
}
666
// Convenience impls: accept owned/borrowed containers of (possibly borrowed)
// tensors and any common integer type for `axis`; each expansion takes views
// and forwards to the core (Vec<TensorAny>, isize) impl.
#[duplicate_item(
    ImplType         ImplStruct                            ;
   [              ] [(&Vec<TensorAny<R, T, B, D>> , isize)];
   [const N: usize] [([TensorAny<R, T, B, D>; N]  , isize)];
   [              ] [(Vec<TensorAny<R, T, B, D>>  , usize)];
   [              ] [(&Vec<TensorAny<R, T, B, D>> , usize)];
   [const N: usize] [([TensorAny<R, T, B, D>; N]  , usize)];
   [              ] [(Vec<TensorAny<R, T, B, D>>  , i32  )];
   [              ] [(&Vec<TensorAny<R, T, B, D>> , i32  )];
   [const N: usize] [([TensorAny<R, T, B, D>; N]  , i32  )];
   [              ] [(Vec<&TensorAny<R, T, B, D>> , isize)];
   [              ] [(&Vec<&TensorAny<R, T, B, D>>, isize)];
   [const N: usize] [([&TensorAny<R, T, B, D>; N] , isize)];
   [              ] [(Vec<&TensorAny<R, T, B, D>> , usize)];
   [              ] [(&Vec<&TensorAny<R, T, B, D>>, usize)];
   [const N: usize] [([&TensorAny<R, T, B, D>; N] , usize)];
   [              ] [(Vec<&TensorAny<R, T, B, D>> , i32  )];
   [              ] [(&Vec<&TensorAny<R, T, B, D>>, i32  )];
   [const N: usize] [([&TensorAny<R, T, B, D>; N] , i32  )];
)]
impl<R, T, B, D, ImplType> StackAPI<()> for ImplStruct
where
    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
    T: Clone + Default,
    D: DimAPI,
    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, IxD>,
{
    type Out = Tensor<T, B, IxD>;

    fn stack_f(self) -> Result<Self::Out> {
        let (tensors, axis) = self;
        #[allow(clippy::unnecessary_cast)]
        let axis = axis as isize;
        // views keep the originals alive; the core impl is generic over R
        let tensors = tensors.iter().map(|t| t.view()).collect::<Vec<_>>();
        StackAPI::stack_f((tensors, axis))
    }
}
704
// Convenience impls without an explicit axis: the new axis defaults to 0.
#[duplicate_item(
    ImplType         ImplStruct                   ;
   [              ] [Vec<TensorAny<R, T, B, D>>  ];
   [              ] [&Vec<TensorAny<R, T, B, D>> ];
   [const N: usize] [[TensorAny<R, T, B, D>; N]  ];
   [              ] [Vec<&TensorAny<R, T, B, D>> ];
   [              ] [&Vec<&TensorAny<R, T, B, D>>];
   [const N: usize] [[&TensorAny<R, T, B, D>; N] ];
)]
impl<R, T, B, D, ImplType> StackAPI<()> for ImplStruct
where
    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
    T: Clone + Default,
    D: DimAPI,
    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, IxD>,
{
    type Out = Tensor<T, B, IxD>;

    fn stack_f(self) -> Result<Self::Out> {
        let tensors = self;
        #[allow(clippy::unnecessary_cast)]
        let axis = 0;
        let tensors = tensors.iter().map(|t| t.view()).collect::<Vec<_>>();
        StackAPI::stack_f((tensors, axis))
    }
}
731
732/* #endregion */
733
734/* #region unstack */
735
/// Backing trait for the [`unstack`] / [`unstack_f`] free functions.
///
/// `Inp` is a marker type parameter used only to keep overlapping impls apart.
pub trait UnstackAPI<Inp> {
    /// Collection of sub-tensor views produced by the split.
    type Out;

    /// Fallible split along an axis.
    fn unstack_f(self) -> Result<Self::Out>;
    /// Unwrapping convenience wrapper around [`UnstackAPI::unstack_f`].
    fn unstack(self) -> Self::Out
    where
        Self: Sized,
    {
        Self::unstack_f(self).rstsr_unwrap()
    }
}
747
748/// Splits an array into a sequence of arrays along the given axis.
749///
750/// # See also
751///
752/// [Python Array Standard `unstack`](https://data-apis.org/array-api/latest/API_specification/generated/array_api.unstack.html)
753pub fn unstack<Args, Inp>(args: Args) -> Args::Out
754where
755    Args: UnstackAPI<Inp>,
756{
757    Args::unstack(args)
758}
759
760pub fn unstack_f<Args, Inp>(args: Args) -> Result<Args::Out>
761where
762    Args: UnstackAPI<Inp>,
763{
764    Args::unstack_f(args)
765}
766
/// Core `unstack` implementation: splits a view into one sub-view per index of
/// `axis` (negative `axis` counts from the end). Each returned view drops the
/// selected axis and borrows the original data — no copy is made.
impl<'a, T, B, D> UnstackAPI<()> for (TensorView<'a, T, B, D>, isize)
where
    T: Clone + Default,
    D: DimAPI + DimSmallerOneAPI,
    D::SmallerOne: DimAPI,
    B: DeviceAPI<T>,
{
    type Out = Vec<TensorView<'a, T, B, D::SmallerOne>>;

    fn unstack_f(self) -> Result<Self::Out> {
        let (tensor, axis) = self;

        // check tensor ndim
        rstsr_assert!(tensor.ndim() > 0, InvalidLayout, "unstack requires a tensor with ndim > 0.")?;

        // check axis, making negative values positive
        let ndim = tensor.ndim();
        let axis = if axis < 0 { ndim as isize + axis } else { axis };
        rstsr_pattern!(axis, 0..ndim as isize, InvalidLayout, "axis out of bounds")?;
        let axis = axis as usize;

        (0..tensor.layout().shape()[axis])
            .map(|i| {
                let view = tensor.view();
                let (storage, layout) = view.into_raw_parts();
                // fix index `i` on `axis`; the resulting layout has one fewer dimension
                let layout = layout.dim_select(axis as isize, i as isize)?;
                // SAFETY: the transmute only re-annotates the borrow lifetime — the
                // storage borrows data owned by the 'a-lived source view, so extending
                // the local view's lifetime back to 'a is sound.
                let storage = unsafe { transmute::<Storage<_, T, B>, Storage<_, T, B>>(storage) };
                unsafe { Ok(TensorBase::new_unchecked(storage, layout)) }
            })
            .collect()
    }
}
800
801impl<'a, R, T, B, D> UnstackAPI<()> for (&'a TensorAny<R, T, B, D>, isize)
802where
803    T: Clone + Default,
804    R: DataAPI<Data = B::Raw>,
805    D: DimAPI + DimSmallerOneAPI,
806    D::SmallerOne: DimAPI,
807    B: DeviceAPI<T>,
808{
809    type Out = Vec<TensorView<'a, T, B, D::SmallerOne>>;
810
811    fn unstack_f(self) -> Result<Self::Out> {
812        let (tensor, axis) = self;
813        UnstackAPI::unstack_f((tensor.view(), axis))
814    }
815}
816
817impl<'a, T, B, D> UnstackAPI<()> for TensorView<'a, T, B, D>
818where
819    T: Clone + Default,
820    D: DimAPI + DimSmallerOneAPI,
821    D::SmallerOne: DimAPI,
822    B: DeviceAPI<T>,
823{
824    type Out = Vec<TensorView<'a, T, B, D::SmallerOne>>;
825
826    fn unstack_f(self) -> Result<Self::Out> {
827        UnstackAPI::unstack_f((self, 0))
828    }
829}
830
831impl<'a, R, T, B, D> UnstackAPI<()> for &'a TensorAny<R, T, B, D>
832where
833    T: Clone + Default,
834    R: DataAPI<Data = B::Raw>,
835    D: DimAPI + DimSmallerOneAPI,
836    D::SmallerOne: DimAPI,
837    B: DeviceAPI<T>,
838{
839    type Out = Vec<TensorView<'a, T, B, D::SmallerOne>>;
840
841    fn unstack_f(self) -> Result<Self::Out> {
842        UnstackAPI::unstack_f((self, 0))
843    }
844}
845
846/* #endregion */
847
848#[cfg(test)]
849mod test {
850    use super::*;
851
    /// Smoke test: diagonal extraction from a 2-D tensor (with and without
    /// offset) and diagonal-matrix construction from a 1-D tensor.
    #[test]
    fn test_diag() {
        let a = arange(9).into_shape([3, 3]);
        let b = diag((&a, 1));
        println!("{b:}");
        let c = a.diag();
        println!("{c:}");
        let c = arange(3) + 1;
        let d = diag((&c, -1));
        println!("{d:}");
    }
863
    /// Smoke test: meshgrid in both "ij" (matrix) and "xy" (Cartesian) modes.
    #[test]
    fn test_meshgrid() {
        let a = arange(3);
        let b = arange(4);
        let c = meshgrid((&vec![&a, &b], "ij", true));
        println!("{c:?}");
        let d = meshgrid((&vec![&a, &b], "xy", true));
        println!("{d:?}");
    }
873
    /// Smoke test: concatenation of 3-D tensors along a negative axis (-2).
    #[test]
    fn test_concat() {
        let a = arange(18).into_shape([2, 3, 3]);
        let b = arange(24).into_shape([2, 4, 3]);
        let c = arange(30).into_shape([2, 5, 3]);
        let d = concat(([a, b, c], -2));
        println!("{d:?}");
    }
882
    /// Smoke test: hstack of 3-D tensors (joined along axis 1).
    #[test]
    fn test_hstack() {
        let a = arange(18).into_shape([2, 3, 3]);
        let b = arange(24).into_shape([2, 4, 3]);
        let c = arange(30).into_shape([2, 5, 3]);
        let d = hstack([a, b, c]);
        println!("{d:?}");
    }
891
    /// Smoke test: stacking along the default new axis (0) and a negative axis (-1).
    #[test]
    fn test_stack() {
        let a = arange(8).into_shape([2, 4]);
        let b = arange(8).into_shape([2, 4]);
        let c = arange(8).into_shape([2, 4]);
        let d = stack([&a, &b, &c]);
        println!("{d:?}");
        let d = stack(([&a, &b, &c], -1));
        println!("{d:?}");
    }
902
    /// Smoke test: splitting along an explicit axis and the default axis (0).
    #[test]
    fn test_unstack() {
        let a = arange(24).into_shape([2, 3, 4]);
        let v = unstack((&a, 2));
        println!("{v:?}");
        let v = unstack(a.view());
        println!("{v:?}");
    }
911}