ariadnetor_tensor/dense/
slice_data.rs1use num_traits::Zero;
8
9use crate::DenseTensorData;
10use ariadnetor_core::MemoryOrder;
11
12use super::{compute_strides_column_usize, compute_strides_usize};
13
14fn strides_for(shape: &[usize], order: MemoryOrder) -> Vec<usize> {
16 match order {
17 MemoryOrder::RowMajor => compute_strides_usize(shape),
18 MemoryOrder::ColumnMajor => compute_strides_column_usize(shape),
19 }
20}
21
22impl<T> DenseTensorData<T>
23where
24 T: Clone,
25{
26 pub fn slice(&self, ranges: &[(usize, usize)]) -> Self {
37 let shape = self.shape();
38 assert_eq!(
39 ranges.len(),
40 shape.len(),
41 "slice: ranges length {} doesn't match rank {}",
42 ranges.len(),
43 shape.len()
44 );
45 for (i, &(start, end)) in ranges.iter().enumerate() {
46 assert!(
47 start <= end && end <= shape[i],
48 "slice: range ({start}, {end}) out of bounds for axis {i} with size {}",
49 shape[i]
50 );
51 }
52
53 let order = self.order();
54 let new_shape: Vec<usize> = ranges.iter().map(|&(s, e)| e - s).collect();
55 let new_total: usize = new_shape.iter().product();
56 let rank = shape.len();
57
58 if new_total == 0 {
59 return DenseTensorData::from_raw_parts(Vec::new(), new_shape, order);
60 }
61 if rank == 0 {
62 return self.clone();
63 }
64
65 let inner_axis = match order {
66 MemoryOrder::RowMajor => rank - 1,
67 MemoryOrder::ColumnMajor => 0,
68 };
69
70 let src_strides = strides_for(shape, order);
71 let raw = self.storage().data();
72 let strip_len = new_shape[inner_axis];
73 let num_strips = new_total / strip_len.max(1);
74
75 let outer_axes: Vec<usize> = match order {
76 MemoryOrder::RowMajor => (0..rank - 1).collect(),
77 MemoryOrder::ColumnMajor => (1..rank).rev().collect(),
78 };
79
80 let mut data = Vec::with_capacity(new_total);
81 let mut outer_coords = vec![0usize; rank];
82 let strip_src_start: usize = ranges
83 .iter()
84 .zip(&src_strides)
85 .map(|(&(s, _), &st)| s * st)
86 .sum();
87 let mut outer_flat = strip_src_start;
88
89 for _ in 0..num_strips {
90 data.extend_from_slice(&raw[outer_flat..outer_flat + strip_len]);
91
92 for &d in outer_axes.iter().rev() {
93 outer_coords[d] += 1;
94 outer_flat += src_strides[d];
95 if outer_coords[d] < new_shape[d] {
96 break;
97 }
98 outer_flat -= new_shape[d] * src_strides[d];
99 outer_coords[d] = 0;
100 }
101 }
102
103 DenseTensorData::from_raw_parts(data, new_shape, order)
104 }
105
106 pub fn expand(&self, padding: &[(usize, usize)]) -> Self
111 where
112 T: Zero,
113 {
114 let shape = self.shape();
115 assert_eq!(
116 padding.len(),
117 shape.len(),
118 "expand: padding length {} doesn't match rank {}",
119 padding.len(),
120 shape.len()
121 );
122
123 let order = self.order();
124 let new_shape: Vec<usize> = shape
125 .iter()
126 .zip(padding)
127 .map(|(&s, &(before, after))| s + before + after)
128 .collect();
129 let new_total: usize = new_shape.iter().product();
130 let dst_strides = strides_for(&new_shape, order);
131 let rank = shape.len();
132 let mut data = vec![T::zero(); new_total];
133
134 let src_total = self.len();
135 if src_total == 0 || rank == 0 {
136 if src_total == 1 {
137 data[0] = self.storage().data()[0].clone();
138 }
139 return DenseTensorData::from_raw_parts(data, new_shape, order);
140 }
141
142 let inner_axis = match order {
143 MemoryOrder::RowMajor => rank - 1,
144 MemoryOrder::ColumnMajor => 0,
145 };
146 let no_inner_pad = padding[inner_axis] == (0, 0);
147 let src_strides = strides_for(shape, order);
148
149 if no_inner_pad {
150 let raw = self.storage().data();
151 let strip_len = shape[inner_axis];
152 let outer_axes: Vec<usize> = match order {
153 MemoryOrder::RowMajor => (0..rank - 1).collect(),
154 MemoryOrder::ColumnMajor => (1..rank).rev().collect(),
155 };
156 let num_strips = src_total / strip_len.max(1);
157 let mut src_offset = 0usize;
158 let mut dst_flat: usize = (0..rank).map(|d| padding[d].0 * dst_strides[d]).sum();
159 let mut outer_coords = vec![0usize; rank];
160
161 for _ in 0..num_strips {
162 data[dst_flat..dst_flat + strip_len]
163 .clone_from_slice(&raw[src_offset..src_offset + strip_len]);
164 src_offset += strip_len;
165 for &d in outer_axes.iter().rev() {
166 outer_coords[d] += 1;
167 dst_flat += dst_strides[d];
168 if outer_coords[d] < shape[d] {
169 break;
170 }
171 dst_flat -= shape[d] * dst_strides[d];
172 outer_coords[d] = 0;
173 }
174 }
175 return DenseTensorData::from_raw_parts(data, new_shape, order);
176 }
177
178 let raw = self.storage().data();
179 let mut coords = vec![0usize; rank];
180 let axis_order: Vec<usize> = match order {
181 MemoryOrder::RowMajor => (0..rank).collect(),
182 MemoryOrder::ColumnMajor => (0..rank).rev().collect(),
183 };
184 let mut src_flat: usize = 0;
185 let mut dst_flat: usize = (0..rank).map(|d| padding[d].0 * dst_strides[d]).sum();
186
187 for _ in 0..src_total {
188 data[dst_flat] = raw[src_flat].clone();
189 for &d in axis_order.iter().rev() {
190 coords[d] += 1;
191 src_flat += src_strides[d];
192 dst_flat += dst_strides[d];
193 if coords[d] < shape[d] {
194 break;
195 }
196 src_flat -= shape[d] * src_strides[d];
197 dst_flat -= shape[d] * dst_strides[d];
198 coords[d] = 0;
199 }
200 }
201
202 DenseTensorData::from_raw_parts(data, new_shape, order)
203 }
204
205 pub fn replace_slice(&mut self, sub: &Self, begin: &[usize]) {
217 let shape: Vec<usize> = self.shape().to_vec();
218 let sub_shape = sub.shape();
219 assert_eq!(
220 sub_shape.len(),
221 shape.len(),
222 "replace_slice: sub rank {} doesn't match rank {}",
223 sub_shape.len(),
224 shape.len()
225 );
226 assert_eq!(
227 begin.len(),
228 shape.len(),
229 "replace_slice: begin length {} doesn't match rank {}",
230 begin.len(),
231 shape.len()
232 );
233 for (d, (&b, &ss)) in begin.iter().zip(sub_shape).enumerate() {
234 assert!(
235 b + ss <= shape[d],
236 "replace_slice: sub-tensor exceeds boundary on axis {d} ({b} + {ss} > {})",
237 shape[d]
238 );
239 }
240
241 let rank = shape.len();
242 let sub_total = sub.len();
243 if sub_total == 0 {
244 return;
245 }
246
247 if rank == 0 {
248 self.storage_mut().data_mut()[0] = sub.storage().data()[0].clone();
249 return;
250 }
251
252 let order = self.order();
253 if rank >= 2 {
254 assert_eq!(
255 sub.order(),
256 order,
257 "replace_slice: sub.order() ({:?}) must equal self.order() ({:?}) at rank >= 2",
258 sub.order(),
259 order,
260 );
261 }
262
263 let inner_axis = match order {
264 MemoryOrder::RowMajor => rank - 1,
265 MemoryOrder::ColumnMajor => 0,
266 };
267 let self_strides = strides_for(&shape, order);
268 let sub_raw = sub.storage().data();
269 let strip_len = sub_shape[inner_axis];
270 let num_strips = sub_total / strip_len.max(1);
271 let outer_axes: Vec<usize> = match order {
272 MemoryOrder::RowMajor => (0..rank - 1).collect(),
273 MemoryOrder::ColumnMajor => (1..rank).rev().collect(),
274 };
275
276 let dst_buf = self.storage_mut().data_mut();
277 let mut src_offset = 0usize;
278 let mut dst_flat: usize = begin.iter().zip(&self_strides).map(|(&b, &s)| b * s).sum();
279 let mut outer_coords = vec![0usize; rank];
280
281 for _ in 0..num_strips {
282 dst_buf[dst_flat..dst_flat + strip_len]
283 .clone_from_slice(&sub_raw[src_offset..src_offset + strip_len]);
284 src_offset += strip_len;
285
286 for &d in outer_axes.iter().rev() {
287 outer_coords[d] += 1;
288 dst_flat += self_strides[d];
289 if outer_coords[d] < sub_shape[d] {
290 break;
291 }
292 dst_flat -= sub_shape[d] * self_strides[d];
293 outer_coords[d] = 0;
294 }
295 }
296 }
297}