1use core::mem::transmute;
9
10use crate::prelude_dev::*;
11
12pub trait DiagAPI<Inp> {
15 type Out;
16
17 fn diag_f(self) -> Result<Self::Out>;
18 fn diag(self) -> Self::Out
19 where
20 Self: Sized,
21 {
22 Self::diag_f(self).rstsr_unwrap()
23 }
24}
25
26pub fn diag<Args, Inp>(param: Args) -> Args::Out
35where
36 Args: DiagAPI<Inp>,
37{
38 Args::diag(param)
39}
40
41pub fn diag_f<Args, Inp>(param: Args) -> Result<Args::Out>
42where
43 Args: DiagAPI<Inp>,
44{
45 Args::diag_f(param)
46}
47
48impl<R, T, B, D> DiagAPI<()> for (&TensorAny<R, T, B, D>, isize)
49where
50 R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
51 T: Clone + Default,
52 D: DimAPI,
53 B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, Ix1>,
54{
55 type Out = Tensor<T, B, IxD>;
56
57 fn diag_f(self) -> Result<Self::Out> {
58 let (tensor, offset) = self;
59 if tensor.ndim() == 1 {
60 let layout_diag = tensor.layout().to_dim::<Ix1>()?;
61 let n_row = tensor.size() + offset.unsigned_abs();
62 let mut result = full_f(([n_row, n_row], T::default(), tensor.device()))?;
63 let layout_result = result.layout().diagonal(Some(offset), Some(0), Some(1))?;
64 let device = tensor.device();
65 device.assign(result.raw_mut(), &layout_result.to_dim()?, tensor.raw(), &layout_diag)?;
66 return Ok(result);
67 } else if tensor.ndim() == 2 {
68 let layout = tensor.layout().to_dim::<Ix2>()?;
69 let layout_diag = layout.diagonal(Some(offset), Some(0), Some(1))?;
70 let size = layout_diag.size();
71 let device = tensor.device();
72 let mut result = unsafe { empty_f(([size], device))? };
73 let layout_result = result.layout().to_dim()?;
74 device.assign(result.raw_mut(), &layout_result, tensor.raw(), &layout_diag)?;
75 return Ok(result);
76 } else {
77 return rstsr_raise!(InvalidLayout, "diag only support 1-D or 2-D tensor.");
78 }
79 }
80}
81
82impl<R, T, B, D> DiagAPI<()> for &TensorAny<R, T, B, D>
83where
84 R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
85 T: Clone + Default,
86 D: DimAPI,
87 B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, Ix1>,
88{
89 type Out = Tensor<T, B, IxD>;
90
91 fn diag_f(self) -> Result<Self::Out> {
92 return diag_f((self, 0));
93 }
94}
95
96pub trait MeshgridAPI<Inp> {
101 type Out;
102
103 fn meshgrid_f(self) -> Result<Self::Out>;
104 fn meshgrid(self) -> Self::Out
105 where
106 Self: Sized,
107 {
108 Self::meshgrid_f(self).rstsr_unwrap()
109 }
110}
111
112pub fn meshgrid<Args, Inp>(args: Args) -> Args::Out
118where
119 Args: MeshgridAPI<Inp>,
120{
121 Args::meshgrid(args)
122}
123
124pub fn meshgrid_f<Args, Inp>(args: Args) -> Result<Args::Out>
125where
126 Args: MeshgridAPI<Inp>,
127{
128 Args::meshgrid_f(args)
129}
130
131impl<R, T, B, D> MeshgridAPI<()> for (Vec<&TensorAny<R, T, B, D>>, &str, bool)
132where
133 R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw> + DataCloneAPI,
134 T: Clone,
135 D: DimAPI,
136 B: DeviceAPI<T>
137 + DeviceRawAPI<MaybeUninit<T>>
138 + DeviceCreationAnyAPI<T>
139 + OpAssignAPI<T, IxD>
140 + OpAssignArbitaryAPI<T, IxD, IxD>,
141 <B as DeviceRawAPI<T>>::Raw: Clone,
142{
143 type Out = Vec<Tensor<T, B, IxD>>;
144
145 fn meshgrid_f(self) -> Result<Self::Out> {
146 let (tensors, indexing, copy) = self;
147
148 match indexing {
149 "ij" | "xy" => (),
150 _ => rstsr_raise!(InvalidValue, "indexing must be 'ij' or 'xy'.")?,
151 }
152
153 if tensors.is_empty() {
155 return Ok(vec![]);
156 } else if tensors.len() == 1 {
157 let tensor = tensors[0];
158 rstsr_assert_eq!(tensor.ndim(), 1, InvalidLayout, "meshgrid only support 1-D tensor.")?;
159 return Ok(vec![tensor.view().into_dim().into_owned()]);
160 }
161
162 let device = tensors[0].device();
166 tensors.iter().try_for_each(|tensor| -> Result<()> {
167 rstsr_assert_eq!(tensor.ndim(), 1, InvalidLayout, "meshgrid only support 1-D tensor.")?;
168 rstsr_assert!(
169 tensor.device().same_device(device),
170 DeviceMismatch,
171 "All tensors must be on the same device."
172 )?;
173 Ok(())
174 })?;
175
176 let ndim = tensors.len();
177 let s0 = vec![1isize; ndim];
178
179 let tensors = tensors
181 .iter()
182 .enumerate()
183 .map(|(i, tensor)| {
184 let mut shape_new = s0.clone();
185 if indexing == "xy" && i == 0 {
186 shape_new[1] = -1;
188 } else if indexing == "xy" && i == 1 {
189 shape_new[0] = -1;
191 } else {
192 shape_new[i] = -1;
194 }
195 tensor.view().into_dim::<IxD>().into_shape_f(shape_new)
196 })
197 .collect::<Result<Vec<_>>>()?;
198 let tensors = broadcast_arrays_f(tensors)?;
200
201 if !copy {
202 Ok(tensors)
203 } else {
204 tensors.into_iter().map(|t| t.into_contig_f(device.default_order())).collect()
205 }
206 }
207}
208
// Forwarding impls of `MeshgridAPI` for containers of tensor *references*.
//
// Each row of the table generates one impl:
// - `ImplStruct` is the accepted argument type.
// - `tuple_args` is the pattern that destructures `self`.
// - `tuple_internal` builds the `(Vec<&TensorAny>, &str, bool)` triple consumed by the
//   primary impl.
// Options omitted from the argument take the values written in the table:
// `indexing = "xy"` and `copy = true`.
#[duplicate_item(
    ImplType          ImplStruct                                         tuple_args                    tuple_internal                                 ;
    [ ]               [(&Vec<&TensorAny<R, T, B, D>>, &str, bool)]       [(tensors, indexing, copy)]   [(tensors.to_vec(), indexing, copy)];
    [const N: usize]  [([&TensorAny<R, T, B, D>; N] , &str, bool)]       [(tensors, indexing, copy)]   [(tensors.to_vec(), indexing, copy)];
    [ ]               [(Vec<&TensorAny<R, T, B, D>> , &str, )]           [(tensors, indexing, )]       [(tensors.to_vec(), indexing, true)];
    [ ]               [(&Vec<&TensorAny<R, T, B, D>>, &str, )]           [(tensors, indexing, )]       [(tensors.to_vec(), indexing, true)];
    [const N: usize]  [([&TensorAny<R, T, B, D>; N] , &str, )]           [(tensors, indexing, )]       [(tensors.to_vec(), indexing, true)];
    [ ]               [(Vec<&TensorAny<R, T, B, D>> , bool)]             [(tensors, copy)]             [(tensors.to_vec(), "xy" , copy)];
    [ ]               [(&Vec<&TensorAny<R, T, B, D>>, bool)]             [(tensors, copy)]             [(tensors.to_vec(), "xy" , copy)];
    [const N: usize]  [([&TensorAny<R, T, B, D>; N] , bool)]             [(tensors, copy)]             [(tensors.to_vec(), "xy" , copy)];
    [ ]               [ Vec<&TensorAny<R, T, B, D>> ]                    [ tensors ]                   [(tensors.to_vec(), "xy" , true)];
    [ ]               [ &Vec<&TensorAny<R, T, B, D>> ]                   [ tensors ]                   [(tensors.to_vec(), "xy" , true)];
    [const N: usize]  [ [&TensorAny<R, T, B, D>; N] ]                    [ tensors ]                   [(tensors.to_vec(), "xy" , true)];
)]
impl<R, T, B, D, ImplType> MeshgridAPI<()> for ImplStruct
where
    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw> + DataCloneAPI,
    T: Clone,
    D: DimAPI,
    B: DeviceAPI<T>
        + DeviceRawAPI<MaybeUninit<T>>
        + DeviceCreationAnyAPI<T>
        + OpAssignAPI<T, IxD>
        + OpAssignArbitaryAPI<T, IxD, IxD>,
    <B as DeviceRawAPI<T>>::Raw: Clone,
{
    type Out = Vec<Tensor<T, B, IxD>>;

    /// Normalizes the argument into `(Vec<&TensorAny>, indexing, copy)` and forwards it.
    fn meshgrid_f(self) -> Result<Self::Out> {
        let tuple_args = self;
        // `to_vec` copies the list of references into the `Vec` the primary impl expects.
        let (tensors, indexing, copy) = tuple_internal;
        MeshgridAPI::meshgrid_f((tensors, indexing, copy))
    }
}
244
// Forwarding impls of `MeshgridAPI` for containers of *owned* tensors.
//
// Each row of the table generates one impl:
// - `ImplStruct` is the accepted argument type.
// - `tuple_args` destructures `self`.
// - `tuple_internal` supplies `(indexing, copy)`. Omitted options are filled with `"xy"`
//   and `true`.
// The tensors are then borrowed into a `Vec<&TensorAny>` and forwarded to the
// `(Vec<&TensorAny>, &str, bool)` impl.
#[duplicate_item(
    ImplType          ImplStruct                                        tuple_args                    tuple_internal     ;
    [ ]               [(Vec<TensorAny<R, T, B, D>> , &str, bool)]       [(tensors, indexing, copy)]   [(indexing, copy)];
    [ ]               [(&Vec<TensorAny<R, T, B, D>>, &str, bool)]       [(tensors, indexing, copy)]   [(indexing, copy)];
    [const N: usize]  [([TensorAny<R, T, B, D>; N] , &str, bool)]       [(tensors, indexing, copy)]   [(indexing, copy)];
    [ ]               [(Vec<TensorAny<R, T, B, D>> , &str, )]           [(tensors, indexing, )]       [(indexing, true)];
    [ ]               [(&Vec<TensorAny<R, T, B, D>>, &str, )]           [(tensors, indexing, )]       [(indexing, true)];
    [const N: usize]  [([TensorAny<R, T, B, D>; N] , &str, )]           [(tensors, indexing, )]       [(indexing, true)];
    [ ]               [(Vec<TensorAny<R, T, B, D>> , bool)]             [(tensors, copy)]             [("xy" , copy)];
    [ ]               [(&Vec<TensorAny<R, T, B, D>>, bool)]             [(tensors, copy)]             [("xy" , copy)];
    [const N: usize]  [([TensorAny<R, T, B, D>; N] , bool)]             [(tensors, copy)]             [("xy" , copy)];
    [ ]               [ Vec<TensorAny<R, T, B, D>> ]                    [ tensors ]                   [("xy" , true)];
    [ ]               [ &Vec<TensorAny<R, T, B, D>> ]                   [ tensors ]                   [("xy" , true)];
    [const N: usize]  [ [TensorAny<R, T, B, D>; N] ]                    [ tensors ]                   [("xy" , true)];
)]
impl<R, T, B, D, ImplType> MeshgridAPI<()> for ImplStruct
where
    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw> + DataCloneAPI,
    T: Clone,
    D: DimAPI,
    B: DeviceAPI<T>
        + DeviceRawAPI<MaybeUninit<T>>
        + DeviceCreationAnyAPI<T>
        + OpAssignAPI<T, IxD>
        + OpAssignArbitaryAPI<T, IxD, IxD>,
    <B as DeviceRawAPI<T>>::Raw: Clone,
{
    type Out = Vec<Tensor<T, B, IxD>>;

    /// Borrows the owned tensors and forwards to the reference-based impl.
    fn meshgrid_f(self) -> Result<Self::Out> {
        let tuple_args = self;
        let (indexing, copy) = tuple_internal;
        // Collect `&TensorAny` borrows so the `(Vec<&TensorAny>, &str, bool)` impl can be reused.
        let tensors = tensors.iter().collect::<Vec<_>>();
        MeshgridAPI::meshgrid_f((tensors, indexing, copy))
    }
}
282
283pub trait ConcatAPI<Inp> {
288 type Out;
289
290 fn concat_f(self) -> Result<Self::Out>;
291 fn concat(self) -> Self::Out
292 where
293 Self: Sized,
294 {
295 Self::concat_f(self).rstsr_unwrap()
296 }
297}
298
299pub fn concat<Args, Inp>(args: Args) -> Args::Out
305where
306 Args: ConcatAPI<Inp>,
307{
308 Args::concat(args)
309}
310
311pub fn concat_f<Args, Inp>(args: Args) -> Result<Args::Out>
312where
313 Args: ConcatAPI<Inp>,
314{
315 Args::concat_f(args)
316}
317
318pub use concat as concatenate;
319pub use concat_f as concatenate_f;
320
321impl<R, T, B, D> ConcatAPI<()> for (Vec<TensorAny<R, T, B, D>>, isize)
322where
323 R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
324 T: Clone + Default,
325 D: DimAPI,
326 B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, IxD>,
327{
328 type Out = Tensor<T, B, IxD>;
329
330 fn concat_f(self) -> Result<Self::Out> {
331 let (tensors, axis) = self;
332
333 rstsr_assert!(!tensors.is_empty(), InvalidValue, "concat requires at least one tensor.")?;
335
336 let device = tensors[0].device().clone();
338 let ndim = tensors[0].ndim();
339
340 rstsr_assert!(ndim > 0, InvalidLayout, "All tensors must have ndim > 0 in concat.")?;
341 tensors.iter().try_for_each(|tensor| -> Result<()> {
342 rstsr_assert_eq!(tensor.ndim(), ndim, InvalidLayout, "All tensors must have the same ndim.")?;
343 rstsr_assert!(
344 tensor.device().same_device(&device),
345 DeviceMismatch,
346 "All tensors must be on the same device."
347 )?;
348 Ok(())
349 })?;
350
351 let axis = if axis < 0 { ndim as isize + axis } else { axis };
353 rstsr_pattern!(axis, 0..ndim as isize, InvalidLayout, "axis out of bounds")?;
354 let axis = axis as usize;
355
356 let mut new_axis_size = 0;
359 let mut shape_other = tensors[0].shape().as_ref().to_vec();
360 shape_other.remove(axis);
361 for tensor in &tensors {
362 let mut shape_other_i = tensor.shape().as_ref().to_vec();
363 new_axis_size += shape_other_i.remove(axis);
364 rstsr_assert_eq!(
365 shape_other_i,
366 shape_other,
367 InvalidLayout,
368 "All tensors must have the same shape except for the concatenation axis."
369 )?;
370 }
371 shape_other.insert(axis, new_axis_size);
372 let new_shape = shape_other;
373
374 let mut result = unsafe { empty_f((new_shape, &device))? };
376
377 let mut offset = 0;
379 for tensor in tensors {
380 let layout = tensor.layout().to_dim::<IxD>()?;
381 let axis_size = tensor.shape()[axis];
382 let layout_result = result.layout().dim_narrow(axis as isize, slice!(offset, offset + axis_size))?;
383 device.assign(result.raw_mut(), &layout_result, tensor.raw(), &layout)?;
384 offset += axis_size;
385 }
386
387 Ok(result)
388 }
389}
390
// Forwarding impls of `ConcatAPI` for the other container / axis-type combinations.
//
// Each row generates one impl that converts the axis to `isize`, borrows every tensor as a
// view, and forwards to the `(Vec<TensorAny>, isize)` impl. That impl is written
// separately, which is why it is absent from this table.
#[duplicate_item(
    ImplType          ImplStruct                                 ;
    [ ]               [(&Vec<TensorAny<R, T, B, D>> , isize)];
    [const N: usize]  [([TensorAny<R, T, B, D>; N] , isize)];
    [ ]               [(Vec<TensorAny<R, T, B, D>> , usize)];
    [ ]               [(&Vec<TensorAny<R, T, B, D>> , usize)];
    [const N: usize]  [([TensorAny<R, T, B, D>; N] , usize)];
    [ ]               [(Vec<TensorAny<R, T, B, D>> , i32 )];
    [ ]               [(&Vec<TensorAny<R, T, B, D>> , i32 )];
    [const N: usize]  [([TensorAny<R, T, B, D>; N] , i32 )];
    [ ]               [(Vec<&TensorAny<R, T, B, D>> , isize)];
    [ ]               [(&Vec<&TensorAny<R, T, B, D>>, isize)];
    [const N: usize]  [([&TensorAny<R, T, B, D>; N] , isize)];
    [ ]               [(Vec<&TensorAny<R, T, B, D>> , usize)];
    [ ]               [(&Vec<&TensorAny<R, T, B, D>>, usize)];
    [const N: usize]  [([&TensorAny<R, T, B, D>; N] , usize)];
    [ ]               [(Vec<&TensorAny<R, T, B, D>> , i32 )];
    [ ]               [(&Vec<&TensorAny<R, T, B, D>>, i32 )];
    [const N: usize]  [([&TensorAny<R, T, B, D>; N] , i32 )];
)]
impl<R, T, B, D, ImplType> ConcatAPI<()> for ImplStruct
where
    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
    T: Clone + Default,
    D: DimAPI,
    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, IxD>,
{
    type Out = Tensor<T, B, IxD>;

    /// Normalizes the axis to `isize`, views every tensor, and forwards to the primary impl.
    fn concat_f(self) -> Result<Self::Out> {
        let (tensors, axis) = self;
        // For the `isize` rows this cast is a no-op, which is what the allow silences.
        #[allow(clippy::unnecessary_cast)]
        let axis = axis as isize;
        let tensors = tensors.iter().map(|t| t.view()).collect::<Vec<_>>();
        ConcatAPI::concat_f((tensors, axis))
    }
}
428
429#[duplicate_item(
430 ImplType ImplStruct ;
431 [ ] [Vec<TensorAny<R, T, B, D>> ];
432 [ ] [&Vec<TensorAny<R, T, B, D>> ];
433 [const N: usize] [[TensorAny<R, T, B, D>; N] ];
434 [ ] [Vec<&TensorAny<R, T, B, D>> ];
435 [ ] [&Vec<&TensorAny<R, T, B, D>>];
436 [const N: usize] [[&TensorAny<R, T, B, D>; N] ];
437)]
438impl<R, T, B, D, ImplType> ConcatAPI<()> for ImplStruct
439where
440 R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
441 T: Clone + Default,
442 D: DimAPI,
443 B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, IxD>,
444{
445 type Out = Tensor<T, B, IxD>;
446
447 fn concat_f(self) -> Result<Self::Out> {
448 let tensors = self;
449 #[allow(clippy::unnecessary_cast)]
450 let axis = 0;
451 let tensors = tensors.iter().map(|t| t.view()).collect::<Vec<_>>();
452 ConcatAPI::concat_f((tensors, axis))
453 }
454}
455
456pub trait HStackAPI<Inp> {
461 type Out;
462
463 fn hstack_f(self) -> Result<Self::Out>;
464 fn hstack(self) -> Self::Out
465 where
466 Self: Sized,
467 {
468 Self::hstack_f(self).rstsr_unwrap()
469 }
470}
471
472pub fn hstack<Args, Inp>(args: Args) -> Args::Out
478where
479 Args: HStackAPI<Inp>,
480{
481 Args::hstack(args)
482}
483
484pub fn hstack_f<Args, Inp>(args: Args) -> Result<Args::Out>
485where
486 Args: HStackAPI<Inp>,
487{
488 Args::hstack_f(args)
489}
490
491#[duplicate_item(
492 ImplType ImplStruct ;
493 [ ] [Vec<TensorAny<R, T, B, D>> ];
494 [ ] [&Vec<TensorAny<R, T, B, D>> ];
495 [const N: usize] [[TensorAny<R, T, B, D>; N] ];
496 [ ] [Vec<&TensorAny<R, T, B, D>> ];
497 [ ] [&Vec<&TensorAny<R, T, B, D>>];
498 [const N: usize] [[&TensorAny<R, T, B, D>; N] ];
499)]
500impl<R, T, B, D, ImplType> HStackAPI<()> for ImplStruct
501where
502 R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
503 T: Clone + Default,
504 D: DimAPI,
505 B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, IxD>,
506{
507 type Out = Tensor<T, B, IxD>;
508
509 fn hstack_f(self) -> Result<Self::Out> {
510 let tensors = self;
511
512 if tensors.is_empty() {
513 return rstsr_raise!(InvalidValue, "hstack requires at least one tensor.");
514 }
515
516 if tensors[0].ndim() == 1 {
517 ConcatAPI::concat_f((tensors, 0))
518 } else {
519 ConcatAPI::concat_f((tensors, 1))
520 }
521 }
522}
523
524pub trait VStackAPI<Inp> {
529 type Out;
530
531 fn vstack_f(self) -> Result<Self::Out>;
532 fn vstack(self) -> Self::Out
533 where
534 Self: Sized,
535 {
536 Self::vstack_f(self).rstsr_unwrap()
537 }
538}
539
540pub fn vstack<Args, Inp>(args: Args) -> Args::Out
546where
547 Args: VStackAPI<Inp>,
548{
549 Args::vstack(args)
550}
551
552pub fn vstack_f<Args, Inp>(args: Args) -> Result<Args::Out>
553where
554 Args: VStackAPI<Inp>,
555{
556 Args::vstack_f(args)
557}
558
559#[duplicate_item(
560 ImplType ImplStruct ;
561 [ ] [Vec<TensorAny<R, T, B, D>> ];
562 [ ] [&Vec<TensorAny<R, T, B, D>> ];
563 [const N: usize] [[TensorAny<R, T, B, D>; N] ];
564 [ ] [Vec<&TensorAny<R, T, B, D>> ];
565 [ ] [&Vec<&TensorAny<R, T, B, D>>];
566 [const N: usize] [[&TensorAny<R, T, B, D>; N] ];
567)]
568impl<R, T, B, D, ImplType> VStackAPI<()> for ImplStruct
569where
570 R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
571 T: Clone + Default,
572 D: DimAPI,
573 B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, IxD>,
574{
575 type Out = Tensor<T, B, IxD>;
576
577 fn vstack_f(self) -> Result<Self::Out> {
578 let tensors = self;
579
580 if tensors.is_empty() {
581 return rstsr_raise!(InvalidValue, "vstack requires at least one tensor.");
582 }
583
584 ConcatAPI::concat_f((tensors, 0))
585 }
586}
587
588pub trait StackAPI<Inp> {
593 type Out;
594
595 fn stack_f(self) -> Result<Self::Out>;
596 fn stack(self) -> Self::Out
597 where
598 Self: Sized,
599 {
600 Self::stack_f(self).rstsr_unwrap()
601 }
602}
603
604pub fn stack<Args, Inp>(args: Args) -> Args::Out
610where
611 Args: StackAPI<Inp>,
612{
613 Args::stack(args)
614}
615
616pub fn stack_f<Args, Inp>(args: Args) -> Result<Args::Out>
617where
618 Args: StackAPI<Inp>,
619{
620 Args::stack_f(args)
621}
622
623impl<R, T, B, D> StackAPI<()> for (Vec<TensorAny<R, T, B, D>>, isize)
624where
625 R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
626 T: Clone + Default,
627 D: DimAPI,
628 B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, IxD>,
629{
630 type Out = Tensor<T, B, IxD>;
631
632 fn stack_f(self) -> Result<Self::Out> {
633 let (tensors, axis) = self;
634
635 rstsr_assert!(!tensors.is_empty(), InvalidValue, "stack requires at least one tensor.")?;
637
638 let device = tensors[0].device().clone();
640 let ndim = tensors[0].ndim();
641 let shape_orig = tensors[0].shape();
642
643 rstsr_assert!(ndim > 0, InvalidLayout, "All tensors must have ndim > 0 in stack.")?;
644 tensors.iter().try_for_each(|tensor| -> Result<()> {
645 rstsr_assert_eq!(tensor.shape(), shape_orig, InvalidLayout, "All tensors must have the same shape.")?;
646 rstsr_assert!(
647 tensor.device().same_device(&device),
648 DeviceMismatch,
649 "All tensors must be on the same device."
650 )?;
651 Ok(())
652 })?;
653
654 let axis = if axis < 0 { ndim as isize + axis + 1 } else { axis };
656 rstsr_pattern!(axis, 0..=ndim as isize, InvalidLayout, "axis out of bounds")?;
657 let axis = axis as usize;
658
659 let tensors = tensors.into_iter().map(|tensor| tensor.into_expand_dims_f(axis)).collect::<Result<Vec<_>>>()?;
661
662 ConcatAPI::concat_f((tensors, axis as isize))
664 }
665}
666
// Forwarding impls of `StackAPI` for the other container / axis-type combinations.
//
// Each row generates one impl that converts the axis to `isize`, borrows every tensor as a
// view, and forwards to the `(Vec<TensorAny>, isize)` impl. That impl is written
// separately, which is why it is absent from this table.
#[duplicate_item(
    ImplType          ImplStruct                                 ;
    [ ]               [(&Vec<TensorAny<R, T, B, D>> , isize)];
    [const N: usize]  [([TensorAny<R, T, B, D>; N] , isize)];
    [ ]               [(Vec<TensorAny<R, T, B, D>> , usize)];
    [ ]               [(&Vec<TensorAny<R, T, B, D>> , usize)];
    [const N: usize]  [([TensorAny<R, T, B, D>; N] , usize)];
    [ ]               [(Vec<TensorAny<R, T, B, D>> , i32 )];
    [ ]               [(&Vec<TensorAny<R, T, B, D>> , i32 )];
    [const N: usize]  [([TensorAny<R, T, B, D>; N] , i32 )];
    [ ]               [(Vec<&TensorAny<R, T, B, D>> , isize)];
    [ ]               [(&Vec<&TensorAny<R, T, B, D>>, isize)];
    [const N: usize]  [([&TensorAny<R, T, B, D>; N] , isize)];
    [ ]               [(Vec<&TensorAny<R, T, B, D>> , usize)];
    [ ]               [(&Vec<&TensorAny<R, T, B, D>>, usize)];
    [const N: usize]  [([&TensorAny<R, T, B, D>; N] , usize)];
    [ ]               [(Vec<&TensorAny<R, T, B, D>> , i32 )];
    [ ]               [(&Vec<&TensorAny<R, T, B, D>>, i32 )];
    [const N: usize]  [([&TensorAny<R, T, B, D>; N] , i32 )];
)]
impl<R, T, B, D, ImplType> StackAPI<()> for ImplStruct
where
    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
    T: Clone + Default,
    D: DimAPI,
    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, IxD>,
{
    type Out = Tensor<T, B, IxD>;

    /// Normalizes the axis to `isize`, views every tensor, and forwards to the primary impl.
    fn stack_f(self) -> Result<Self::Out> {
        let (tensors, axis) = self;
        // For the `isize` rows this cast is a no-op, which is what the allow silences.
        #[allow(clippy::unnecessary_cast)]
        let axis = axis as isize;
        let tensors = tensors.iter().map(|t| t.view()).collect::<Vec<_>>();
        StackAPI::stack_f((tensors, axis))
    }
}
704
705#[duplicate_item(
706 ImplType ImplStruct ;
707 [ ] [Vec<TensorAny<R, T, B, D>> ];
708 [ ] [&Vec<TensorAny<R, T, B, D>> ];
709 [const N: usize] [[TensorAny<R, T, B, D>; N] ];
710 [ ] [Vec<&TensorAny<R, T, B, D>> ];
711 [ ] [&Vec<&TensorAny<R, T, B, D>>];
712 [const N: usize] [[&TensorAny<R, T, B, D>; N] ];
713)]
714impl<R, T, B, D, ImplType> StackAPI<()> for ImplStruct
715where
716 R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
717 T: Clone + Default,
718 D: DimAPI,
719 B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, IxD>,
720{
721 type Out = Tensor<T, B, IxD>;
722
723 fn stack_f(self) -> Result<Self::Out> {
724 let tensors = self;
725 #[allow(clippy::unnecessary_cast)]
726 let axis = 0;
727 let tensors = tensors.iter().map(|t| t.view()).collect::<Vec<_>>();
728 StackAPI::stack_f((tensors, axis))
729 }
730}
731
732pub trait UnstackAPI<Inp> {
737 type Out;
738
739 fn unstack_f(self) -> Result<Self::Out>;
740 fn unstack(self) -> Self::Out
741 where
742 Self: Sized,
743 {
744 Self::unstack_f(self).rstsr_unwrap()
745 }
746}
747
748pub fn unstack<Args, Inp>(args: Args) -> Args::Out
754where
755 Args: UnstackAPI<Inp>,
756{
757 Args::unstack(args)
758}
759
760pub fn unstack_f<Args, Inp>(args: Args) -> Result<Args::Out>
761where
762 Args: UnstackAPI<Inp>,
763{
764 Args::unstack_f(args)
765}
766
impl<'a, T, B, D> UnstackAPI<()> for (TensorView<'a, T, B, D>, isize)
where
    T: Clone + Default,
    D: DimAPI + DimSmallerOneAPI,
    D::SmallerOne: DimAPI,
    B: DeviceAPI<T>,
{
    type Out = Vec<TensorView<'a, T, B, D::SmallerOne>>;

    /// Splits a view into sub-views along `axis`, removing that axis from each.
    ///
    /// Returns one view for each index `i` in `0..shape[axis]`, built from
    /// `dim_select(axis, i)` on the view's layout; no element data is copied.
    /// A negative `axis` counts from the end.
    ///
    /// # Errors
    /// `InvalidLayout` if the view is 0-D or `axis` is out of range. Errors from `dim_select`
    /// are propagated.
    fn unstack_f(self) -> Result<Self::Out> {
        let (tensor, axis) = self;

        rstsr_assert!(tensor.ndim() > 0, InvalidLayout, "unstack requires a tensor with ndim > 0.")?;

        let ndim = tensor.ndim();
        // Normalize a negative axis to its non-negative equivalent.
        let axis = if axis < 0 { ndim as isize + axis } else { axis };
        rstsr_pattern!(axis, 0..ndim as isize, InvalidLayout, "axis out of bounds")?;
        let axis = axis as usize;

        (0..tensor.layout().shape()[axis])
            .map(|i| {
                // `tensor.view()` borrows the local `tensor`, so this storage's lifetime is
                // shorter than `'a`.
                let view = tensor.view();
                let (storage, layout) = view.into_raw_parts();
                let layout = layout.dim_select(axis as isize, i as isize)?;
                // SAFETY (review): the transmute only rewrites the storage's lifetime parameter,
                // widening it to `'a` so the sub-view can outlive this function. The intent is
                // that the data belongs to whatever the original `'a` borrow points at, not to
                // `tensor`.
                // NOTE(review): confirm that `view()` of a `TensorView` never refers to data
                // stored inside `tensor` itself, for every `DataRef` variant. `tensor` is dropped
                // when this function returns.
                let storage = unsafe { transmute::<Storage<_, T, B>, Storage<_, T, B>>(storage) };
                // SAFETY (review): `layout` comes from `dim_select` on this storage's own layout.
                // Presumably it stays in bounds; confirm against `new_unchecked`'s contract.
                unsafe { Ok(TensorBase::new_unchecked(storage, layout)) }
            })
            .collect()
    }
}
800
801impl<'a, R, T, B, D> UnstackAPI<()> for (&'a TensorAny<R, T, B, D>, isize)
802where
803 T: Clone + Default,
804 R: DataAPI<Data = B::Raw>,
805 D: DimAPI + DimSmallerOneAPI,
806 D::SmallerOne: DimAPI,
807 B: DeviceAPI<T>,
808{
809 type Out = Vec<TensorView<'a, T, B, D::SmallerOne>>;
810
811 fn unstack_f(self) -> Result<Self::Out> {
812 let (tensor, axis) = self;
813 UnstackAPI::unstack_f((tensor.view(), axis))
814 }
815}
816
817impl<'a, T, B, D> UnstackAPI<()> for TensorView<'a, T, B, D>
818where
819 T: Clone + Default,
820 D: DimAPI + DimSmallerOneAPI,
821 D::SmallerOne: DimAPI,
822 B: DeviceAPI<T>,
823{
824 type Out = Vec<TensorView<'a, T, B, D::SmallerOne>>;
825
826 fn unstack_f(self) -> Result<Self::Out> {
827 UnstackAPI::unstack_f((self, 0))
828 }
829}
830
831impl<'a, R, T, B, D> UnstackAPI<()> for &'a TensorAny<R, T, B, D>
832where
833 T: Clone + Default,
834 R: DataAPI<Data = B::Raw>,
835 D: DimAPI + DimSmallerOneAPI,
836 D::SmallerOne: DimAPI,
837 B: DeviceAPI<T>,
838{
839 type Out = Vec<TensorView<'a, T, B, D::SmallerOne>>;
840
841 fn unstack_f(self) -> Result<Self::Out> {
842 UnstackAPI::unstack_f((self, 0))
843 }
844}
845
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_diag() {
        let square = arange(9).into_shape([3, 3]);
        // First super-diagonal of a 3x3 matrix.
        let upper = diag((&square, 1));
        println!("{upper:}");
        // Main diagonal through the method form.
        let main_diag = square.diag();
        println!("{main_diag:}");
        // A vector placed on the first sub-diagonal of a new square matrix.
        let vector = arange(3) + 1;
        let lower = diag((&vector, -1));
        println!("{lower:}");
    }

    #[test]
    fn test_meshgrid() {
        let xs = arange(3);
        let ys = arange(4);
        let grid_ij = meshgrid((&vec![&xs, &ys], "ij", true));
        println!("{grid_ij:?}");
        let grid_xy = meshgrid((&vec![&xs, &ys], "xy", true));
        println!("{grid_xy:?}");
    }

    #[test]
    fn test_concat() {
        let first = arange(18).into_shape([2, 3, 3]);
        let second = arange(24).into_shape([2, 4, 3]);
        let third = arange(30).into_shape([2, 5, 3]);
        // Join along the middle axis, addressed from the end.
        let joined = concat(([first, second, third], -2));
        println!("{joined:?}");
    }

    #[test]
    fn test_hstack() {
        let first = arange(18).into_shape([2, 3, 3]);
        let second = arange(24).into_shape([2, 4, 3]);
        let third = arange(30).into_shape([2, 5, 3]);
        let joined = hstack([first, second, third]);
        println!("{joined:?}");
    }

    #[test]
    fn test_stack() {
        let p = arange(8).into_shape([2, 4]);
        let q = arange(8).into_shape([2, 4]);
        let r = arange(8).into_shape([2, 4]);
        let leading = stack([&p, &q, &r]);
        println!("{leading:?}");
        let trailing = stack(([&p, &q, &r], -1));
        println!("{trailing:?}");
    }

    #[test]
    fn test_unstack() {
        let cube = arange(24).into_shape([2, 3, 4]);
        let slices_last = unstack((&cube, 2));
        println!("{slices_last:?}");
        let slices_first = unstack(cube.view());
        println!("{slices_first:?}");
    }
}
911}