Skip to main content

ariadnetor_tensor/dense/
multi_tensor.rs

1//! Multi-tensor operations on `DenseTensorData<T>`: concatenate and stack.
2
3use crate::DenseTensorData;
4use ariadnetor_core::MemoryOrder;
5
6impl<T> DenseTensorData<T>
7where
8    T: Clone,
9{
10    /// Concatenate tensors along an existing axis.
11    ///
12    /// All tensors must have the same rank, the same `order()`, and
13    /// matching sizes on all axes except `axis`. The output preserves
14    /// the shared `order()`.
15    pub fn concatenate(tensors: &[&DenseTensorData<T>], axis: usize) -> Self {
16        assert!(!tensors.is_empty(), "concatenate: empty tensor list");
17        let rank = tensors[0].rank();
18        assert!(
19            axis < rank,
20            "concatenate: axis {axis} out of range for rank {rank}"
21        );
22
23        let order = tensors[0].order();
24        let base_shape = tensors[0].shape().to_vec();
25        for (i, t) in tensors.iter().enumerate().skip(1) {
26            assert_eq!(
27                t.rank(),
28                rank,
29                "concatenate: tensor {i} has rank {} but expected {rank}",
30                t.rank()
31            );
32            assert_eq!(
33                t.order(),
34                order,
35                "concatenate: tensor {i} has order {:?} but expected {:?}",
36                t.order(),
37                order,
38            );
39            for (d, (&ts, &bs)) in t.shape().iter().zip(&base_shape).enumerate() {
40                if d != axis {
41                    assert_eq!(
42                        ts, bs,
43                        "concatenate: tensor {i} has size {ts} on axis {d} but expected {bs}",
44                    );
45                }
46            }
47        }
48
49        let mut out_shape: Vec<usize> = base_shape.clone();
50        out_shape[axis] = tensors.iter().map(|t| t.shape()[axis]).sum();
51        let out_total: usize = out_shape.iter().product();
52
53        if out_total == 0 {
54            return DenseTensorData::from_raw_parts(Vec::new(), out_shape, order);
55        }
56
57        let is_outermost = match order {
58            MemoryOrder::RowMajor => axis == 0,
59            MemoryOrder::ColumnMajor => axis == rank - 1,
60        };
61
62        let mut data = Vec::with_capacity(out_total);
63        if is_outermost {
64            for t in tensors {
65                data.extend_from_slice(t.storage().data());
66            }
67        } else {
68            let (strip_len, outer_count) = match order {
69                MemoryOrder::RowMajor => (
70                    base_shape[axis + 1..].iter().product::<usize>(),
71                    base_shape[..axis].iter().product::<usize>(),
72                ),
73                MemoryOrder::ColumnMajor => (
74                    base_shape[..axis].iter().product::<usize>(),
75                    base_shape[axis + 1..].iter().product::<usize>(),
76                ),
77            };
78
79            for outer in 0..outer_count {
80                for t in tensors {
81                    let t_axis_size = t.shape()[axis];
82                    let block_size = t_axis_size * strip_len;
83                    let src_start = outer * block_size;
84                    let src = &t.storage().data()[src_start..src_start + block_size];
85                    data.extend_from_slice(src);
86                }
87            }
88        }
89
90        DenseTensorData::from_raw_parts(data, out_shape, order)
91    }
92
93    /// Stack tensors along a new axis.
94    ///
95    /// All tensors must have the same shape and the same `order()`.
96    /// The output preserves the shared `order()`.
97    pub fn stack(tensors: &[&DenseTensorData<T>], axis: usize) -> Self {
98        assert!(!tensors.is_empty(), "stack: empty tensor list");
99        let base_shape = tensors[0].shape().to_vec();
100        let rank = tensors[0].rank();
101        assert!(
102            axis <= rank,
103            "stack: axis {axis} out of range for rank {rank} (max {rank})"
104        );
105
106        for (i, t) in tensors.iter().enumerate().skip(1) {
107            assert_eq!(
108                t.shape(),
109                base_shape.as_slice(),
110                "stack: tensor {i} has shape {:?} but expected {base_shape:?}",
111                t.shape()
112            );
113        }
114
115        let mut new_shape = Vec::with_capacity(rank + 1);
116        new_shape.extend_from_slice(&base_shape[..axis]);
117        new_shape.push(1);
118        new_shape.extend_from_slice(&base_shape[axis..]);
119
120        let reshaped: Vec<DenseTensorData<T>> = tensors
121            .iter()
122            .map(|t| t.reshape(new_shape.clone()))
123            .collect();
124        let reshaped_refs: Vec<&DenseTensorData<T>> = reshaped.iter().collect();
125
126        Self::concatenate(&reshaped_refs, axis)
127    }
128}