1use crate::prelude_dev::*;
2
3#[non_exhaustive]
4#[derive(Debug, Clone, PartialEq, Eq)]
5pub enum Indexer {
6 Slice(SliceI),
9 Select(isize),
11 Insert,
14 Ellipsis,
16}
17
18pub use Indexer::Ellipsis;
19pub use Indexer::Insert as NewAxis;
20
21impl<R> From<R> for Indexer
24where
25 R: Into<SliceI>,
26{
27 fn from(slice: R) -> Self {
28 Self::Slice(slice.into())
29 }
30}
31
32impl From<Option<usize>> for Indexer {
33 fn from(opt: Option<usize>) -> Self {
34 match opt {
35 Some(_) => panic!("Option<T> should not be used in Indexer."),
36 None => Self::Insert,
37 }
38 }
39}
40
41macro_rules! impl_from_int_into_indexer {
42 ($($t:ty),*) => {
43 $(
44 impl From<$t> for Indexer {
45 fn from(index: $t) -> Self {
46 Self::Select(index as isize)
47 }
48 }
49 )*
50 };
51}
52
53impl_from_int_into_indexer!(usize, isize, u32, i32, u64, i64);
54
55macro_rules! impl_into_axes_index {
60 ($($t:ty),*) => {
61 $(
62 impl From<$t> for AxesIndex<Indexer> {
63 fn from(index: $t) -> Self {
64 AxesIndex::Val(index.into())
65 }
66 }
67
68 impl<const N: usize> From<[$t; N]> for AxesIndex<Indexer> {
69 fn from(index: [$t; N]) -> Self {
70 let index = index.iter().map(|v| v.clone().into()).collect::<Vec<_>>();
71 AxesIndex::Vec(index)
72 }
73 }
74
75 impl From<Vec<$t>> for AxesIndex<Indexer> {
76 fn from(index: Vec<$t>) -> Self {
77 let index = index.iter().map(|v| v.clone().into()).collect::<Vec<_>>();
78 AxesIndex::Vec(index)
79 }
80 }
81 )*
82 };
83}
84
85impl_into_axes_index!(usize, isize, u32, i32, u64, i64);
86impl_into_axes_index!(Option<usize>);
87impl_into_axes_index!(
88 Slice<isize>,
89 core::ops::Range<isize>,
90 core::ops::RangeFrom<isize>,
91 core::ops::RangeTo<isize>,
92 core::ops::Range<usize>,
93 core::ops::RangeFrom<usize>,
94 core::ops::RangeTo<usize>,
95 core::ops::Range<i32>,
96 core::ops::RangeFrom<i32>,
97 core::ops::RangeTo<i32>,
98 core::ops::RangeFull
99);
100
101impl_from_tuple_to_axes_index!(Indexer);
102
103pub trait IndexerPreserveAPI: Sized {
106 fn dim_narrow(&self, axis: isize, slice: SliceI) -> Result<Self>;
108}
109
110impl<D> IndexerPreserveAPI for Layout<D>
111where
112 D: DimDevAPI,
113{
114 fn dim_narrow(&self, axis: isize, slice: SliceI) -> Result<Self> {
115 let axis = if axis < 0 { self.ndim() as isize + axis } else { axis };
117 rstsr_pattern!(axis, 0..self.ndim() as isize, ValueOutOfRange)?;
118 let axis = axis as usize;
119
120 let mut shape = self.shape().clone();
122 let mut stride = self.stride().clone();
123
124 if slice == Slice::new(None, None, None) {
126 return Ok(self.clone());
127 }
128
129 let len_prev = shape[axis] as isize;
131
132 let step = slice.step().unwrap_or(1);
134 rstsr_assert!(step != 0, InvalidValue)?;
135
136 if len_prev == 0 {
138 return Ok(self.clone());
139 }
140
141 if step > 0 {
142 let mut start = slice.start().unwrap_or(0);
144 let mut stop = slice.stop().unwrap_or(len_prev);
145
146 if start < 0 {
148 start = (len_prev + start).max(0);
149 }
150 if stop < 0 {
151 stop = (len_prev + stop).max(0);
152 }
153
154 if start > len_prev || start > stop {
155 start = 0;
157 stop = 0;
158 } else if stop > len_prev {
159 stop = len_prev;
161 }
162
163 let offset = (self.offset() as isize + stride[axis] * start) as usize;
164 shape[axis] = ((stop - start + step - 1) / step).max(0) as usize;
165 stride[axis] *= step;
166 return Self::new(shape, stride, offset);
167 } else {
168 let mut start = slice.start().unwrap_or(len_prev - 1);
171 let mut stop = slice.stop().unwrap_or(-1);
172
173 if start < 0 {
175 start = (len_prev + start).max(0);
176 }
177 if stop < -1 {
178 stop = (len_prev + stop).max(-1);
179 }
180
181 if stop > len_prev - 1 || stop > start {
182 start = 0;
184 stop = 0;
185 } else if start > len_prev - 1 {
186 start = len_prev - 1;
188 }
189
190 let offset = (self.offset() as isize + stride[axis] * start) as usize;
191 shape[axis] = ((stop - start + step + 1) / step).max(0) as usize;
192 stride[axis] *= step;
193 return Self::new(shape, stride, offset);
194 }
195 }
196}
197
198pub trait IndexerSmallerOneAPI {
199 type DOut: DimDevAPI;
200
201 fn dim_select(&self, axis: isize, index: isize) -> Result<Layout<Self::DOut>>;
203
204 fn dim_eliminate(&self, axis: isize) -> Result<Layout<Self::DOut>>;
206}
207
208impl<D> IndexerSmallerOneAPI for Layout<D>
209where
210 D: DimDevAPI + DimSmallerOneAPI,
211 D::SmallerOne: DimDevAPI,
212{
213 type DOut = <D as DimSmallerOneAPI>::SmallerOne;
214
215 fn dim_select(&self, axis: isize, index: isize) -> Result<Layout<Self::DOut>> {
216 let axis = if axis < 0 { self.ndim() as isize + axis } else { axis };
218 rstsr_pattern!(axis, 0..self.ndim() as isize, ValueOutOfRange)?;
219 let axis = axis as usize;
220
221 let shape = self.shape();
223 let stride = self.stride();
224 let mut offset = self.offset() as isize;
225 let mut shape_new = vec![];
226 let mut stride_new = vec![];
227
228 for (i, (&d, &s)) in shape.as_ref().iter().zip(stride.as_ref().iter()).enumerate() {
230 if i == axis {
231 let idx = if index < 0 { d as isize + index } else { index };
233 rstsr_pattern!(idx, 0..d as isize, ValueOutOfRange)?;
234 offset += s * idx;
235 } else {
236 shape_new.push(d);
238 stride_new.push(s);
239 }
240 }
241
242 let offset = offset as usize;
243 let layout = Layout::<IxD>::new(shape_new, stride_new, offset)?;
244 return layout.into_dim();
245 }
246
247 fn dim_eliminate(&self, axis: isize) -> Result<Layout<Self::DOut>> {
248 let axis = if axis < 0 { self.ndim() as isize + axis } else { axis };
250 rstsr_pattern!(axis, 0..self.ndim() as isize, ValueOutOfRange)?;
251 let axis = axis as usize;
252
253 let mut shape = self.shape().as_ref().to_vec();
255 let mut stride = self.stride().as_ref().to_vec();
256 let offset = self.offset();
257
258 if shape[axis] != 1 {
259 rstsr_raise!(InvalidValue, "Dimension to be eliminated is not 1.")?;
260 }
261
262 shape.remove(axis);
263 stride.remove(axis);
264
265 let layout = Layout::<IxD>::new(shape, stride, offset)?;
266 return layout.into_dim();
267 }
268}
269
270pub trait IndexerLargerOneAPI {
271 type DOut: DimDevAPI;
272
273 fn dim_insert(&self, axis: isize) -> Result<Layout<Self::DOut>>;
276}
277
278impl<D> IndexerLargerOneAPI for Layout<D>
279where
280 D: DimDevAPI + DimLargerOneAPI,
281 D::LargerOne: DimDevAPI,
282{
283 type DOut = <D as DimLargerOneAPI>::LargerOne;
284
285 fn dim_insert(&self, axis: isize) -> Result<Layout<Self::DOut>> {
286 let axis = if axis < 0 { self.ndim() as isize + axis + 1 } else { axis };
288 rstsr_pattern!(axis, 0..(self.ndim() + 1) as isize, ValueOutOfRange)?;
289 let axis = axis as usize;
290
291 let is_f_prefer = self.f_prefer();
293 let mut shape = self.shape().as_ref().to_vec();
294 let mut stride = self.stride().as_ref().to_vec();
295 let offset = self.offset();
296
297 if is_f_prefer {
298 if axis == 0 {
299 shape.insert(0, 1);
300 stride.insert(0, 1);
301 } else {
302 shape.insert(axis, 1);
303 stride.insert(axis, stride[axis - 1]);
304 }
305 } else if axis == self.ndim() {
306 shape.push(1);
307 stride.push(1);
308 } else {
309 shape.insert(axis, 1);
310 stride.insert(axis, stride[axis]);
311 }
312
313 let layout = Layout::new(shape, stride, offset)?;
314 return layout.into_dim();
315 }
316}
317
318pub trait IndexerDynamicAPI: IndexerPreserveAPI {
319 fn dim_slice(&self, indexers: &[Indexer]) -> Result<Layout<IxD>>;
321
322 fn dim_split_at(&self, axis: isize) -> Result<(Layout<IxD>, Layout<IxD>)>;
324
325 fn dim_split_axes(&self, axes: &[isize]) -> Result<(Layout<IxD>, Layout<IxD>)>;
326}
327
328impl<D> IndexerDynamicAPI for Layout<D>
329where
330 D: DimDevAPI,
331{
332 fn dim_slice(&self, indexers: &[Indexer]) -> Result<Layout<IxD>> {
333 let shape = self.shape().as_ref().to_vec();
335 let stride = self.stride().as_ref().to_vec();
336 let mut layout = Layout::new(shape, stride, self.offset)?;
337
338 let mut indexers = indexers.to_vec();
340
341 let mut counter_slice = 0;
343 let mut counter_select = 0;
344 let mut idx_ellipsis = None;
345 for (n, indexer) in indexers.iter().enumerate() {
346 match indexer {
347 Indexer::Slice(_) => counter_slice += 1,
348 Indexer::Select(_) => counter_select += 1,
349 Indexer::Ellipsis => match idx_ellipsis {
350 Some(_) => rstsr_raise!(InvalidValue, "Only one ellipsis indexer allowed.")?,
351 None => idx_ellipsis = Some(n),
352 },
353 _ => {},
354 }
355 }
356
357 rstsr_pattern!(counter_slice + counter_select, 0..=self.ndim(), ValueOutOfRange)?;
359
360 let n_ellipsis = self.ndim() - counter_slice - counter_select;
362 if n_ellipsis == 0 {
363 if let Some(idx) = idx_ellipsis {
364 indexers.remove(idx);
365 }
366 } else if let Some(idx_ellipsis) = idx_ellipsis {
367 indexers[idx_ellipsis] = SliceI::new(None, None, None).into();
368 if n_ellipsis > 1 {
369 for _ in 1..n_ellipsis {
370 indexers.insert(idx_ellipsis, SliceI::new(None, None, None).into());
371 }
372 }
373 } else {
374 for _ in 0..n_ellipsis {
375 indexers.push(SliceI::new(None, None, None).into());
376 }
377 }
378
379 let mut cur_dim = self.ndim() as isize;
382 for indexer in indexers.iter().rev() {
383 match indexer {
384 Indexer::Slice(slice) => {
385 cur_dim -= 1;
386 layout = layout.dim_narrow(cur_dim, *slice)?;
387 },
388 Indexer::Select(index) => {
389 cur_dim -= 1;
390 layout = layout.dim_select(cur_dim, *index)?;
391 },
392 Indexer::Insert => {
393 layout = layout.dim_insert(cur_dim)?;
394 },
395 _ => rstsr_raise!(InvalidValue, "Invalid indexer found : {:?}", indexer)?,
396 }
397 }
398
399 rstsr_assert!(cur_dim == 0, Miscellaneous, "Internal program error in indexer.")?;
401
402 return Ok(layout);
403 }
404
405 fn dim_split_at(&self, axis: isize) -> Result<(Layout<IxD>, Layout<IxD>)> {
406 let axis = if axis < 0 { self.ndim() as isize + axis } else { axis };
409 rstsr_pattern!(axis, 0..=self.ndim() as isize, ValueOutOfRange)?;
410 let axis = axis as usize;
411
412 let shape = self.shape().as_ref().to_vec();
414 let stride = self.stride().as_ref().to_vec();
415 let offset = self.offset();
416
417 let (shape1, shape2) = shape.split_at(axis);
418 let (stride1, stride2) = stride.split_at(axis);
419
420 let layout1 = unsafe { Layout::new_unchecked(shape1.to_vec(), stride1.to_vec(), offset) };
421 let layout2 = unsafe { Layout::new_unchecked(shape2.to_vec(), stride2.to_vec(), offset) };
422 return Ok((layout1, layout2));
423 }
424
425 fn dim_split_axes(&self, axes: &[isize]) -> Result<(Layout<IxD>, Layout<IxD>)> {
426 let mut axes_update: Vec<usize> = vec![];
431 for &axis in axes {
432 let axis = if axis < 0 { self.ndim() as isize + axis } else { axis };
433 rstsr_pattern!(axis, 0..self.ndim() as isize, ValueOutOfRange)?;
434 axes_update.push(axis as usize);
435 }
436
437 let axes_check = axes_update.clone();
439 axes_update.sort();
440 axes_update.dedup();
441 rstsr_assert_eq!(
442 axes_update.len(),
443 axes_check.len(),
444 InvalidLayout,
445 "Same axis is not allowed for this function."
446 )?;
447
448 let axes_rest =
451 (0..self.ndim()).filter(|&axis| !axes_update.contains(&axis)).collect::<Vec<_>>();
452
453 let offset = self.offset();
455 let shape_axes = axes_update.iter().map(|&axis| self.shape()[axis]).collect::<Vec<_>>();
456 let strides_axes = axes_update.iter().map(|&axis| self.stride()[axis]).collect::<Vec<_>>();
457 let layout_axes = Layout::new(shape_axes, strides_axes, offset)?;
458
459 let shape_rest = axes_rest.iter().map(|&axis| self.shape()[axis]).collect::<Vec<_>>();
460 let strides_rest = axes_rest.iter().map(|&axis| self.stride()[axis]).collect::<Vec<_>>();
461 let layout_rest = Layout::new(shape_rest, strides_rest, offset)?;
462
463 return Ok((layout_axes, layout_rest));
464 }
465}
466
/// Build a `Slice<isize>` from `stop`, `start, stop`, or `start, stop, step`.
///
/// Omitted parts are passed to `Slice::new` as `None`; the result is converted with
/// `Slice::<isize>::from`.
#[macro_export]
macro_rules! slice {
    ($stop:expr) => {
        $crate::slice!(None, $stop, None)
    };
    ($start:expr, $stop:expr) => {
        $crate::slice!($start, $stop, None)
    };
    ($start:expr, $stop:expr, $step:expr) => {{
        use $crate::layout::slice::Slice;
        Slice::<isize>::from(Slice::new($start, $stop, $step))
    }};
}
483
/// Build an index list by calling `.into()` on every argument and borrowing the resulting
/// array as a slice. For example, `layout.dim_slice(s![Ellipsis, 1..3, None, 2])`.
///
/// The slice borrows a temporary array, so it is meant to be passed straight into a call.
/// A trailing comma is accepted.
#[macro_export]
macro_rules! s {
    [$($slc:expr),* $(,)?] => {
        [$(($slc).into()),*].as_ref()
    };
}
491
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_slice() {
        let step = 3_usize;
        let sl = slice!(1, 2, step);
        assert_eq!(sl.start(), Some(1));
        assert_eq!(sl.stop(), Some(2));
        assert_eq!(sl.step(), Some(3));
    }

    #[test]
    fn test_slice_at_dim() {
        let f_layout = Layout::new([2, 3, 4], [1, 10, 100], 0).unwrap();
        let reversed = slice!(10, 1, -1);
        let narrowed = f_layout.dim_narrow(1, reversed).unwrap();
        println!("{:?}", narrowed);
        let selected = f_layout.dim_select(1, -2).unwrap();
        println!("{:?}", selected);
        let inserted = f_layout.dim_insert(1).unwrap();
        println!("{:?}", inserted);

        let c_layout = Layout::new([2, 3, 4], [100, 10, 1], 0).unwrap();
        let inserted = c_layout.dim_insert(1).unwrap();
        println!("{:?}", inserted);

        let with_ellipsis = c_layout.dim_slice(s![Indexer::Ellipsis, 1..3, None, 2]).unwrap();
        let with_ellipsis = with_ellipsis.into_dim::<Ix3>().unwrap();
        println!("{:?}", with_ellipsis);
        assert_eq!(with_ellipsis.shape(), &[2, 2, 1]);
        assert_eq!(with_ellipsis.offset(), 12);

        let with_new_axes = c_layout.dim_slice(s![None, 1, None, 1..3]).unwrap();
        let with_new_axes = with_new_axes.into_dim::<Ix4>().unwrap();
        println!("{:?}", with_new_axes);
        assert_eq!(with_new_axes.shape(), &[1, 1, 2, 4]);
        assert_eq!(with_new_axes.offset(), 110);
    }

    #[test]
    fn test_slice_with_stride() {
        let base = Layout::new([24], [1], 0).unwrap();
        let cases = [
            (slice!(5, 15, 2), Layout::new([5], [2], 5).unwrap()),
            (slice!(5, 16, 2), Layout::new([6], [2], 5).unwrap()),
            (slice!(15, 5, -2), Layout::new([5], [-2], 15).unwrap()),
            (slice!(15, 4, -2), Layout::new([6], [-2], 15).unwrap()),
        ];
        for (sl, expected) in cases {
            assert_eq!(base.dim_narrow(0, sl).unwrap(), expected);
        }
    }

    #[test]
    fn test_expand_dims() {
        let base = Layout::<Ix3>::new([2, 3, 4], [1, 10, 100], 0).unwrap();
        for axis in [0, 1, 3, -1, -4] {
            let expanded = base.dim_insert(axis).unwrap();
            println!("{:?}", expanded);
        }
    }
}