ariadnetor_tensor/dense/
multi_tensor.rs1use crate::DenseTensorData;
4use ariadnetor_core::MemoryOrder;
5
6impl<T> DenseTensorData<T>
7where
8 T: Clone,
9{
10 pub fn concatenate(tensors: &[&DenseTensorData<T>], axis: usize) -> Self {
16 assert!(!tensors.is_empty(), "concatenate: empty tensor list");
17 let rank = tensors[0].rank();
18 assert!(
19 axis < rank,
20 "concatenate: axis {axis} out of range for rank {rank}"
21 );
22
23 let order = tensors[0].order();
24 let base_shape = tensors[0].shape().to_vec();
25 for (i, t) in tensors.iter().enumerate().skip(1) {
26 assert_eq!(
27 t.rank(),
28 rank,
29 "concatenate: tensor {i} has rank {} but expected {rank}",
30 t.rank()
31 );
32 assert_eq!(
33 t.order(),
34 order,
35 "concatenate: tensor {i} has order {:?} but expected {:?}",
36 t.order(),
37 order,
38 );
39 for (d, (&ts, &bs)) in t.shape().iter().zip(&base_shape).enumerate() {
40 if d != axis {
41 assert_eq!(
42 ts, bs,
43 "concatenate: tensor {i} has size {ts} on axis {d} but expected {bs}",
44 );
45 }
46 }
47 }
48
49 let mut out_shape: Vec<usize> = base_shape.clone();
50 out_shape[axis] = tensors.iter().map(|t| t.shape()[axis]).sum();
51 let out_total: usize = out_shape.iter().product();
52
53 if out_total == 0 {
54 return DenseTensorData::from_raw_parts(Vec::new(), out_shape, order);
55 }
56
57 let is_outermost = match order {
58 MemoryOrder::RowMajor => axis == 0,
59 MemoryOrder::ColumnMajor => axis == rank - 1,
60 };
61
62 let mut data = Vec::with_capacity(out_total);
63 if is_outermost {
64 for t in tensors {
65 data.extend_from_slice(t.storage().data());
66 }
67 } else {
68 let (strip_len, outer_count) = match order {
69 MemoryOrder::RowMajor => (
70 base_shape[axis + 1..].iter().product::<usize>(),
71 base_shape[..axis].iter().product::<usize>(),
72 ),
73 MemoryOrder::ColumnMajor => (
74 base_shape[..axis].iter().product::<usize>(),
75 base_shape[axis + 1..].iter().product::<usize>(),
76 ),
77 };
78
79 for outer in 0..outer_count {
80 for t in tensors {
81 let t_axis_size = t.shape()[axis];
82 let block_size = t_axis_size * strip_len;
83 let src_start = outer * block_size;
84 let src = &t.storage().data()[src_start..src_start + block_size];
85 data.extend_from_slice(src);
86 }
87 }
88 }
89
90 DenseTensorData::from_raw_parts(data, out_shape, order)
91 }
92
93 pub fn stack(tensors: &[&DenseTensorData<T>], axis: usize) -> Self {
98 assert!(!tensors.is_empty(), "stack: empty tensor list");
99 let base_shape = tensors[0].shape().to_vec();
100 let rank = tensors[0].rank();
101 assert!(
102 axis <= rank,
103 "stack: axis {axis} out of range for rank {rank} (max {rank})"
104 );
105
106 for (i, t) in tensors.iter().enumerate().skip(1) {
107 assert_eq!(
108 t.shape(),
109 base_shape.as_slice(),
110 "stack: tensor {i} has shape {:?} but expected {base_shape:?}",
111 t.shape()
112 );
113 }
114
115 let mut new_shape = Vec::with_capacity(rank + 1);
116 new_shape.extend_from_slice(&base_shape[..axis]);
117 new_shape.push(1);
118 new_shape.extend_from_slice(&base_shape[axis..]);
119
120 let reshaped: Vec<DenseTensorData<T>> = tensors
121 .iter()
122 .map(|t| t.reshape(new_shape.clone()))
123 .collect();
124 let reshaped_refs: Vec<&DenseTensorData<T>> = reshaped.iter().collect();
125
126 Self::concatenate(&reshaped_refs, axis)
127 }
128}