Skip to main content

ariadnetor_tensor/dense/
slice_data.rs

1//! Slice, expand, and replace operations for `DenseTensorData<T>`.
2//!
3//! Strip-copy machinery for `slice`, `expand`, and `replace_slice`:
4//! read via `storage().data()` and construct outputs via
5//! [`DenseTensorData::from_raw_parts`].
6
7use num_traits::Zero;
8
9use crate::DenseTensorData;
10use ariadnetor_core::MemoryOrder;
11
12use super::{compute_strides_column_usize, compute_strides_usize};
13
14/// Compute strides (usize) for the given shape and order.
15fn strides_for(shape: &[usize], order: MemoryOrder) -> Vec<usize> {
16    match order {
17        MemoryOrder::RowMajor => compute_strides_usize(shape),
18        MemoryOrder::ColumnMajor => compute_strides_column_usize(shape),
19    }
20}
21
22impl<T> DenseTensorData<T>
23where
24    T: Clone,
25{
26    /// Extract a sub-tensor by specifying a `(start, end)` range for
27    /// each axis (exclusive end).
28    ///
29    /// The flat-data interpretation follows `self.order()`, and the
30    /// output preserves the same order.
31    ///
32    /// # Panics
33    ///
34    /// Panics if `ranges` length doesn't match rank, or any range is
35    /// out of bounds.
36    pub fn slice(&self, ranges: &[(usize, usize)]) -> Self {
37        let shape = self.shape();
38        assert_eq!(
39            ranges.len(),
40            shape.len(),
41            "slice: ranges length {} doesn't match rank {}",
42            ranges.len(),
43            shape.len()
44        );
45        for (i, &(start, end)) in ranges.iter().enumerate() {
46            assert!(
47                start <= end && end <= shape[i],
48                "slice: range ({start}, {end}) out of bounds for axis {i} with size {}",
49                shape[i]
50            );
51        }
52
53        let order = self.order();
54        let new_shape: Vec<usize> = ranges.iter().map(|&(s, e)| e - s).collect();
55        let new_total: usize = new_shape.iter().product();
56        let rank = shape.len();
57
58        if new_total == 0 {
59            return DenseTensorData::from_raw_parts(Vec::new(), new_shape, order);
60        }
61        if rank == 0 {
62            return self.clone();
63        }
64
65        let inner_axis = match order {
66            MemoryOrder::RowMajor => rank - 1,
67            MemoryOrder::ColumnMajor => 0,
68        };
69
70        let src_strides = strides_for(shape, order);
71        let raw = self.storage().data();
72        let strip_len = new_shape[inner_axis];
73        let num_strips = new_total / strip_len.max(1);
74
75        let outer_axes: Vec<usize> = match order {
76            MemoryOrder::RowMajor => (0..rank - 1).collect(),
77            MemoryOrder::ColumnMajor => (1..rank).rev().collect(),
78        };
79
80        let mut data = Vec::with_capacity(new_total);
81        let mut outer_coords = vec![0usize; rank];
82        let strip_src_start: usize = ranges
83            .iter()
84            .zip(&src_strides)
85            .map(|(&(s, _), &st)| s * st)
86            .sum();
87        let mut outer_flat = strip_src_start;
88
89        for _ in 0..num_strips {
90            data.extend_from_slice(&raw[outer_flat..outer_flat + strip_len]);
91
92            for &d in outer_axes.iter().rev() {
93                outer_coords[d] += 1;
94                outer_flat += src_strides[d];
95                if outer_coords[d] < new_shape[d] {
96                    break;
97                }
98                outer_flat -= new_shape[d] * src_strides[d];
99                outer_coords[d] = 0;
100            }
101        }
102
103        DenseTensorData::from_raw_parts(data, new_shape, order)
104    }
105
106    /// Expand tensor by adding zero-padding at the boundaries.
107    ///
108    /// The flat-data interpretation follows `self.order()`, and the
109    /// output preserves the same order.
110    pub fn expand(&self, padding: &[(usize, usize)]) -> Self
111    where
112        T: Zero,
113    {
114        let shape = self.shape();
115        assert_eq!(
116            padding.len(),
117            shape.len(),
118            "expand: padding length {} doesn't match rank {}",
119            padding.len(),
120            shape.len()
121        );
122
123        let order = self.order();
124        let new_shape: Vec<usize> = shape
125            .iter()
126            .zip(padding)
127            .map(|(&s, &(before, after))| s + before + after)
128            .collect();
129        let new_total: usize = new_shape.iter().product();
130        let dst_strides = strides_for(&new_shape, order);
131        let rank = shape.len();
132        let mut data = vec![T::zero(); new_total];
133
134        let src_total = self.len();
135        if src_total == 0 || rank == 0 {
136            if src_total == 1 {
137                data[0] = self.storage().data()[0].clone();
138            }
139            return DenseTensorData::from_raw_parts(data, new_shape, order);
140        }
141
142        let inner_axis = match order {
143            MemoryOrder::RowMajor => rank - 1,
144            MemoryOrder::ColumnMajor => 0,
145        };
146        let no_inner_pad = padding[inner_axis] == (0, 0);
147        let src_strides = strides_for(shape, order);
148
149        if no_inner_pad {
150            let raw = self.storage().data();
151            let strip_len = shape[inner_axis];
152            let outer_axes: Vec<usize> = match order {
153                MemoryOrder::RowMajor => (0..rank - 1).collect(),
154                MemoryOrder::ColumnMajor => (1..rank).rev().collect(),
155            };
156            let num_strips = src_total / strip_len.max(1);
157            let mut src_offset = 0usize;
158            let mut dst_flat: usize = (0..rank).map(|d| padding[d].0 * dst_strides[d]).sum();
159            let mut outer_coords = vec![0usize; rank];
160
161            for _ in 0..num_strips {
162                data[dst_flat..dst_flat + strip_len]
163                    .clone_from_slice(&raw[src_offset..src_offset + strip_len]);
164                src_offset += strip_len;
165                for &d in outer_axes.iter().rev() {
166                    outer_coords[d] += 1;
167                    dst_flat += dst_strides[d];
168                    if outer_coords[d] < shape[d] {
169                        break;
170                    }
171                    dst_flat -= shape[d] * dst_strides[d];
172                    outer_coords[d] = 0;
173                }
174            }
175            return DenseTensorData::from_raw_parts(data, new_shape, order);
176        }
177
178        let raw = self.storage().data();
179        let mut coords = vec![0usize; rank];
180        let axis_order: Vec<usize> = match order {
181            MemoryOrder::RowMajor => (0..rank).collect(),
182            MemoryOrder::ColumnMajor => (0..rank).rev().collect(),
183        };
184        let mut src_flat: usize = 0;
185        let mut dst_flat: usize = (0..rank).map(|d| padding[d].0 * dst_strides[d]).sum();
186
187        for _ in 0..src_total {
188            data[dst_flat] = raw[src_flat].clone();
189            for &d in axis_order.iter().rev() {
190                coords[d] += 1;
191                src_flat += src_strides[d];
192                dst_flat += dst_strides[d];
193                if coords[d] < shape[d] {
194                    break;
195                }
196                src_flat -= shape[d] * src_strides[d];
197                dst_flat -= shape[d] * dst_strides[d];
198                coords[d] = 0;
199            }
200        }
201
202        DenseTensorData::from_raw_parts(data, new_shape, order)
203    }
204
205    /// Write a sub-tensor into this tensor starting at the given
206    /// position (triggers CoW on the storage half if shared).
207    ///
208    /// The flat-data interpretation follows `self.order()`.
209    ///
210    /// # Panics
211    ///
212    /// Panics if `sub.rank()` or `begin.len()` does not match
213    /// `self.rank()`, or any sub-tensor extent exceeds the
214    /// destination's bounds. Also panics if `sub.order()` differs from
215    /// `self.order()` at rank ≥ 2.
216    pub fn replace_slice(&mut self, sub: &Self, begin: &[usize]) {
217        let shape: Vec<usize> = self.shape().to_vec();
218        let sub_shape = sub.shape();
219        assert_eq!(
220            sub_shape.len(),
221            shape.len(),
222            "replace_slice: sub rank {} doesn't match rank {}",
223            sub_shape.len(),
224            shape.len()
225        );
226        assert_eq!(
227            begin.len(),
228            shape.len(),
229            "replace_slice: begin length {} doesn't match rank {}",
230            begin.len(),
231            shape.len()
232        );
233        for (d, (&b, &ss)) in begin.iter().zip(sub_shape).enumerate() {
234            assert!(
235                b + ss <= shape[d],
236                "replace_slice: sub-tensor exceeds boundary on axis {d} ({b} + {ss} > {})",
237                shape[d]
238            );
239        }
240
241        let rank = shape.len();
242        let sub_total = sub.len();
243        if sub_total == 0 {
244            return;
245        }
246
247        if rank == 0 {
248            self.storage_mut().data_mut()[0] = sub.storage().data()[0].clone();
249            return;
250        }
251
252        let order = self.order();
253        if rank >= 2 {
254            assert_eq!(
255                sub.order(),
256                order,
257                "replace_slice: sub.order() ({:?}) must equal self.order() ({:?}) at rank >= 2",
258                sub.order(),
259                order,
260            );
261        }
262
263        let inner_axis = match order {
264            MemoryOrder::RowMajor => rank - 1,
265            MemoryOrder::ColumnMajor => 0,
266        };
267        let self_strides = strides_for(&shape, order);
268        let sub_raw = sub.storage().data();
269        let strip_len = sub_shape[inner_axis];
270        let num_strips = sub_total / strip_len.max(1);
271        let outer_axes: Vec<usize> = match order {
272            MemoryOrder::RowMajor => (0..rank - 1).collect(),
273            MemoryOrder::ColumnMajor => (1..rank).rev().collect(),
274        };
275
276        let dst_buf = self.storage_mut().data_mut();
277        let mut src_offset = 0usize;
278        let mut dst_flat: usize = begin.iter().zip(&self_strides).map(|(&b, &s)| b * s).sum();
279        let mut outer_coords = vec![0usize; rank];
280
281        for _ in 0..num_strips {
282            dst_buf[dst_flat..dst_flat + strip_len]
283                .clone_from_slice(&sub_raw[src_offset..src_offset + strip_len]);
284            src_offset += strip_len;
285
286            for &d in outer_axes.iter().rev() {
287                outer_coords[d] += 1;
288                dst_flat += self_strides[d];
289                if outer_coords[d] < sub_shape[d] {
290                    break;
291                }
292                dst_flat -= sub_shape[d] * self_strides[d];
293                outer_coords[d] = 0;
294            }
295        }
296    }
297}