1use crate::prelude_dev::*;
2
3#[non_exhaustive]
4#[derive(Debug, Clone, PartialEq, Eq)]
5pub enum Indexer {
6 Slice(SliceI),
9 Select(isize),
11 Insert,
14 Ellipsis,
16}
17
18pub use Indexer::Ellipsis;
19pub use Indexer::Insert as NewAxis;
20
21impl<R> From<R> for Indexer
24where
25 R: Into<SliceI>,
26{
27 fn from(slice: R) -> Self {
28 Self::Slice(slice.into())
29 }
30}
31
32impl From<Option<usize>> for Indexer {
33 fn from(opt: Option<usize>) -> Self {
34 match opt {
35 Some(_) => panic!("Option<T> should not be used in Indexer."),
36 None => Self::Insert,
37 }
38 }
39}
40
41macro_rules! impl_from_int_into_indexer {
42 ($($t:ty),*) => {
43 $(
44 impl From<$t> for Indexer {
45 fn from(index: $t) -> Self {
46 Self::Select(index as isize)
47 }
48 }
49 )*
50 };
51}
52
53impl_from_int_into_indexer!(usize, isize, u32, i32, u64, i64);
54
55macro_rules! impl_into_axes_index {
60 ($($t:ty),*) => {
61 $(
62 impl TryFrom<$t> for AxesIndex<Indexer> {
63 type Error = Error;
64
65 fn try_from(index: $t) -> Result<Self> {
66 Ok(AxesIndex::Val(index.try_into()?))
67 }
68 }
69
70 impl<const N: usize> TryFrom<[$t; N]> for AxesIndex<Indexer> {
71 type Error = Error;
72
73 fn try_from(index: [$t; N]) -> Result<Self> {
74 let index = index.iter().map(|v| v.clone().into()).collect::<Vec<_>>();
75 Ok(AxesIndex::Vec(index))
76 }
77 }
78
79 impl TryFrom<Vec<$t>> for AxesIndex<Indexer> {
80 type Error = Error;
81
82 fn try_from(index: Vec<$t>) -> Result<Self> {
83 let index = index.iter().map(|v| v.clone().into()).collect::<Vec<_>>();
84 Ok(AxesIndex::Vec(index))
85 }
86 }
87 )*
88 };
89}
90
91impl_into_axes_index!(usize, isize, u32, i32, u64, i64);
92impl_into_axes_index!(Option<usize>);
93impl_into_axes_index!(
94 Slice<isize>,
95 core::ops::Range<isize>,
96 core::ops::RangeFrom<isize>,
97 core::ops::RangeTo<isize>,
98 core::ops::Range<usize>,
99 core::ops::RangeFrom<usize>,
100 core::ops::RangeTo<usize>,
101 core::ops::Range<i32>,
102 core::ops::RangeFrom<i32>,
103 core::ops::RangeTo<i32>,
104 core::ops::RangeFull
105);
106
107impl_from_tuple_to_axes_index!(Indexer);
108
109pub trait IndexerPreserveAPI: Sized {
112 fn dim_narrow(&self, axis: isize, slice: SliceI) -> Result<Self>;
114}
115
116impl<D> IndexerPreserveAPI for Layout<D>
117where
118 D: DimDevAPI,
119{
120 fn dim_narrow(&self, axis: isize, slice: SliceI) -> Result<Self> {
121 let axis = if axis < 0 { self.ndim() as isize + axis } else { axis };
123 rstsr_pattern!(axis, 0..self.ndim() as isize, ValueOutOfRange)?;
124 let axis = axis as usize;
125
126 let mut shape = self.shape().clone();
128 let mut stride = self.stride().clone();
129
130 if slice == Slice::new(None, None, None) {
132 return Ok(self.clone());
133 }
134
135 let len_prev = shape[axis] as isize;
137
138 let step = slice.step().unwrap_or(1);
140 rstsr_assert!(step != 0, InvalidValue)?;
141
142 if len_prev == 0 {
144 return Ok(self.clone());
145 }
146
147 if step > 0 {
148 let mut start = slice.start().unwrap_or(0);
150 let mut stop = slice.stop().unwrap_or(len_prev);
151
152 if start < 0 {
154 start = (len_prev + start).max(0);
155 }
156 if stop < 0 {
157 stop = (len_prev + stop).max(0);
158 }
159
160 if start > len_prev || start > stop {
161 start = 0;
163 stop = 0;
164 } else if stop > len_prev {
165 stop = len_prev;
167 }
168
169 let offset = (self.offset() as isize + stride[axis] * start) as usize;
170 shape[axis] = ((stop - start + step - 1) / step).max(0) as usize;
171 stride[axis] *= step;
172 return Self::new(shape, stride, offset);
173 } else {
174 let mut start = slice.start().unwrap_or(len_prev - 1);
177 let mut stop = slice.stop().unwrap_or(-1);
178
179 if start < 0 {
181 start = (len_prev + start).max(0);
182 }
183 if stop < -1 {
184 stop = (len_prev + stop).max(-1);
185 }
186
187 if stop > len_prev - 1 || stop > start {
188 start = 0;
190 stop = 0;
191 } else if start > len_prev - 1 {
192 start = len_prev - 1;
194 }
195
196 let offset = (self.offset() as isize + stride[axis] * start) as usize;
197 shape[axis] = ((stop - start + step + 1) / step).max(0) as usize;
198 stride[axis] *= step;
199 return Self::new(shape, stride, offset);
200 }
201 }
202}
203
204pub trait IndexerSmallerOneAPI {
205 type DOut: DimDevAPI;
206
207 fn dim_select(&self, axis: isize, index: isize) -> Result<Layout<Self::DOut>>;
209
210 fn dim_eliminate(&self, axis: isize) -> Result<Layout<Self::DOut>>;
215
216 fn dim_chop(&self, axis: isize) -> Result<Layout<Self::DOut>>;
219}
220
221impl<D> IndexerSmallerOneAPI for Layout<D>
222where
223 D: DimDevAPI + DimSmallerOneAPI,
224 D::SmallerOne: DimDevAPI,
225{
226 type DOut = <D as DimSmallerOneAPI>::SmallerOne;
227
228 fn dim_select(&self, axis: isize, index: isize) -> Result<Layout<Self::DOut>> {
229 let axis = if axis < 0 { self.ndim() as isize + axis } else { axis };
231 rstsr_pattern!(axis, 0..self.ndim() as isize, ValueOutOfRange)?;
232 let axis = axis as usize;
233
234 let shape = self.shape();
236 let stride = self.stride();
237 let mut offset = self.offset() as isize;
238 let mut shape_new = vec![];
239 let mut stride_new = vec![];
240
241 for (i, (&d, &s)) in shape.as_ref().iter().zip(stride.as_ref().iter()).enumerate() {
243 if i == axis {
244 let idx = if index < 0 { d as isize + index } else { index };
246 rstsr_pattern!(idx, 0..d as isize, ValueOutOfRange)?;
247 offset += s * idx;
248 } else {
249 shape_new.push(d);
251 stride_new.push(s);
252 }
253 }
254
255 let offset = offset as usize;
256 let layout = Layout::<IxD>::new(shape_new, stride_new, offset)?;
257 return layout.into_dim();
258 }
259
260 fn dim_eliminate(&self, axis: isize) -> Result<Layout<Self::DOut>> {
261 let axis = if axis < 0 { self.ndim() as isize + axis } else { axis };
263 rstsr_pattern!(axis, 0..self.ndim() as isize, ValueOutOfRange)?;
264 let axis = axis as usize;
265
266 let mut shape = self.shape().as_ref().to_vec();
268 let mut stride = self.stride().as_ref().to_vec();
269 let offset = self.offset();
270
271 if shape[axis] != 1 {
272 rstsr_raise!(InvalidValue, "Dimension to be eliminated is not 1.")?;
273 }
274
275 shape.remove(axis);
276 stride.remove(axis);
277
278 let layout = Layout::<IxD>::new(shape, stride, offset)?;
279 return layout.into_dim();
280 }
281
282 fn dim_chop(&self, axis: isize) -> Result<Layout<Self::DOut>> {
283 let axis = if axis < 0 { self.ndim() as isize + axis } else { axis };
285 rstsr_pattern!(axis, 0..self.ndim() as isize, ValueOutOfRange)?;
286 let axis = axis as usize;
287
288 let mut shape = self.shape().as_ref().to_vec();
290 let mut stride = self.stride().as_ref().to_vec();
291 let offset = self.offset();
292
293 shape.remove(axis);
294 stride.remove(axis);
295
296 let layout = Layout::<IxD>::new(shape, stride, offset)?;
297 return layout.into_dim();
298 }
299}
300
301pub trait IndexerLargerOneAPI {
302 type DOut: DimDevAPI;
303
304 fn dim_insert(&self, axis: isize) -> Result<Layout<Self::DOut>>;
307}
308
309impl<D> IndexerLargerOneAPI for Layout<D>
310where
311 D: DimDevAPI + DimLargerOneAPI,
312 D::LargerOne: DimDevAPI,
313{
314 type DOut = <D as DimLargerOneAPI>::LargerOne;
315
316 fn dim_insert(&self, axis: isize) -> Result<Layout<Self::DOut>> {
317 let axis = if axis < 0 { self.ndim() as isize + axis + 1 } else { axis };
319 rstsr_pattern!(axis, 0..(self.ndim() + 1) as isize, ValueOutOfRange)?;
320 let axis = axis as usize;
321
322 let is_f_prefer = self.f_prefer();
324 let mut shape = self.shape().as_ref().to_vec();
325 let mut stride = self.stride().as_ref().to_vec();
326 let offset = self.offset();
327
328 if is_f_prefer {
329 if axis == 0 {
330 shape.insert(0, 1);
331 stride.insert(0, 1);
332 } else {
333 shape.insert(axis, 1);
334 stride.insert(axis, stride[axis - 1]);
335 }
336 } else if axis == self.ndim() {
337 shape.push(1);
338 stride.push(1);
339 } else {
340 shape.insert(axis, 1);
341 stride.insert(axis, stride[axis]);
342 }
343
344 let layout = Layout::new(shape, stride, offset)?;
345 return layout.into_dim();
346 }
347}
348
349pub trait IndexerDynamicAPI: IndexerPreserveAPI {
350 fn dim_slice(&self, indexers: &[Indexer]) -> Result<Layout<IxD>>;
352
353 fn dim_split_at(&self, axis: isize) -> Result<(Layout<IxD>, Layout<IxD>)>;
355
356 fn dim_split_axes(&self, axes: &[isize]) -> Result<(Layout<IxD>, Layout<IxD>)>;
362}
363
364impl<D> IndexerDynamicAPI for Layout<D>
365where
366 D: DimDevAPI,
367{
368 fn dim_slice(&self, indexers: &[Indexer]) -> Result<Layout<IxD>> {
369 let shape = self.shape().as_ref().to_vec();
371 let stride = self.stride().as_ref().to_vec();
372 let mut layout = Layout::new(shape, stride, self.offset)?;
373
374 let mut indexers = indexers.to_vec();
376
377 let mut counter_slice = 0;
379 let mut counter_select = 0;
380 let mut idx_ellipsis = None;
381 for (n, indexer) in indexers.iter().enumerate() {
382 match indexer {
383 Indexer::Slice(_) => counter_slice += 1,
384 Indexer::Select(_) => counter_select += 1,
385 Indexer::Ellipsis => match idx_ellipsis {
386 Some(_) => rstsr_raise!(InvalidValue, "Only one ellipsis indexer allowed.")?,
387 None => idx_ellipsis = Some(n),
388 },
389 _ => {},
390 }
391 }
392
393 rstsr_pattern!(counter_slice + counter_select, 0..=self.ndim(), ValueOutOfRange)?;
395
396 let n_ellipsis = self.ndim() - counter_slice - counter_select;
398 if n_ellipsis == 0 {
399 if let Some(idx) = idx_ellipsis {
400 indexers.remove(idx);
401 }
402 } else if let Some(idx_ellipsis) = idx_ellipsis {
403 indexers[idx_ellipsis] = SliceI::new(None, None, None).into();
404 if n_ellipsis > 1 {
405 for _ in 1..n_ellipsis {
406 indexers.insert(idx_ellipsis, SliceI::new(None, None, None).into());
407 }
408 }
409 } else {
410 for _ in 0..n_ellipsis {
411 indexers.push(SliceI::new(None, None, None).into());
412 }
413 }
414
415 let mut cur_dim = self.ndim() as isize;
418 for indexer in indexers.iter().rev() {
419 match indexer {
420 Indexer::Slice(slice) => {
421 cur_dim -= 1;
422 layout = layout.dim_narrow(cur_dim, *slice)?;
423 },
424 Indexer::Select(index) => {
425 cur_dim -= 1;
426 layout = layout.dim_select(cur_dim, *index)?;
427 },
428 Indexer::Insert => {
429 layout = layout.dim_insert(cur_dim)?;
430 },
431 _ => rstsr_raise!(InvalidValue, "Invalid indexer found : {:?}", indexer)?,
432 }
433 }
434
435 rstsr_assert!(cur_dim == 0, Miscellaneous, "Internal program error in indexer.")?;
437
438 return Ok(layout);
439 }
440
441 fn dim_split_at(&self, axis: isize) -> Result<(Layout<IxD>, Layout<IxD>)> {
442 let axis = if axis < 0 { self.ndim() as isize + axis } else { axis };
445 rstsr_pattern!(axis, 0..=self.ndim() as isize, ValueOutOfRange)?;
446 let axis = axis as usize;
447
448 let shape = self.shape().as_ref().to_vec();
450 let stride = self.stride().as_ref().to_vec();
451 let offset = self.offset();
452
453 let (shape1, shape2) = shape.split_at(axis);
454 let (stride1, stride2) = stride.split_at(axis);
455
456 let layout1 = unsafe { Layout::new_unchecked(shape1.to_vec(), stride1.to_vec(), offset) };
457 let layout2 = unsafe { Layout::new_unchecked(shape2.to_vec(), stride2.to_vec(), offset) };
458 return Ok((layout1, layout2));
459 }
460
461 fn dim_split_axes(&self, axes: &[isize]) -> Result<(Layout<IxD>, Layout<IxD>)> {
462 let axes_update = normalize_axes_index(axes.into(), self.ndim(), false, false)?
466 .into_iter()
467 .map(|axis| axis as usize)
468 .collect::<Vec<usize>>();
469
470 let axes_rest = (0..self.ndim()).filter(|&axis| !axes_update.contains(&axis)).collect::<Vec<_>>();
473
474 let offset = self.offset();
476 let shape_axes = axes_update.iter().map(|&axis| self.shape()[axis]).collect::<Vec<_>>();
477 let strides_axes = axes_update.iter().map(|&axis| self.stride()[axis]).collect::<Vec<_>>();
478 let layout_axes = Layout::new(shape_axes, strides_axes, offset)?;
479
480 let shape_rest = axes_rest.iter().map(|&axis| self.shape()[axis]).collect::<Vec<_>>();
481 let strides_rest = axes_rest.iter().map(|&axis| self.stride()[axis]).collect::<Vec<_>>();
482 let layout_rest = Layout::new(shape_rest, strides_rest, offset)?;
483
484 return Ok((layout_axes, layout_rest));
485 }
486}
487
/// Build a `Slice<isize>` from Python-like arguments: `slice!(stop)`,
/// `slice!(start, stop)` or `slice!(start, stop, step)`.
///
/// Omitted parts are `None`. Each argument is passed to `Slice::new` and the
/// result is converted into `Slice<isize>`.
#[macro_export]
macro_rules! slice {
    ($stop:expr) => {
        $crate::slice!(None, $stop, None)
    };
    ($start:expr, $stop:expr) => {
        $crate::slice!($start, $stop, None)
    };
    ($start:expr, $stop:expr, $step:expr) => {{
        use $crate::layout::slice::Slice;
        Slice::<isize>::from(Slice::new($start, $stop, $step))
    }};
}
504
/// Build a borrowed slice of indexers for `dim_slice`, e.g. `s![1..3, None, 2]`.
///
/// Each element is converted with `.into()`, so the element type comes from the
/// context in which the slice is used. A trailing comma is accepted.
#[macro_export]
macro_rules! s {
    [$($slc:expr),* $(,)?] => {
        [$(($slc).into()),*].as_ref()
    };
}
512
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_slice() {
        let t = 3_usize;
        let s = slice!(1, 2, t);
        assert_eq!(s.start(), Some(1));
        assert_eq!(s.stop(), Some(2));
        assert_eq!(s.step(), Some(3));
    }

    #[test]
    fn test_slice_at_dim() {
        let l = Layout::new([2, 3, 4], [1, 10, 100], 0).unwrap();

        // `[10:1:-1]` on an axis of length 3: start is clipped to 2, one element remains.
        let s = slice!(10, 1, -1);
        let l1 = l.dim_narrow(1, s).unwrap();
        assert_eq!(l1.shape(), &[2, 1, 4]);
        assert_eq!(l1.stride()[1], -10);
        assert_eq!(l1.offset(), 20);

        // Index -2 on an axis of length 3 is position 1: the axis is dropped and
        // the offset advances by one stride.
        let l2 = l.dim_select(1, -2).unwrap();
        assert_eq!(l2.shape(), &[2, 4]);
        assert_eq!(l2.offset(), 10);

        let l3 = l.dim_insert(1).unwrap();
        assert_eq!(l3.shape(), &[2, 1, 3, 4]);

        let l = Layout::new([2, 3, 4], [100, 10, 1], 0).unwrap();
        let l3 = l.dim_insert(1).unwrap();
        assert_eq!(l3.shape(), &[2, 1, 3, 4]);

        let l4 = l.dim_slice(s![Indexer::Ellipsis, 1..3, None, 2]).unwrap();
        let l4 = l4.into_dim::<Ix3>().unwrap();
        assert_eq!(l4.shape(), &[2, 2, 1]);
        assert_eq!(l4.offset(), 12);

        let l5 = l.dim_slice(s![None, 1, None, 1..3]).unwrap();
        let l5 = l5.into_dim::<Ix4>().unwrap();
        assert_eq!(l5.shape(), &[1, 1, 2, 4]);
        assert_eq!(l5.offset(), 110);
    }

    #[test]
    fn test_slice_with_stride() {
        let l = Layout::new([24], [1], 0).unwrap();
        let b = l.dim_narrow(0, slice!(5, 15, 2)).unwrap();
        assert_eq!(b, Layout::new([5], [2], 5).unwrap());
        let b = l.dim_narrow(0, slice!(5, 16, 2)).unwrap();
        assert_eq!(b, Layout::new([6], [2], 5).unwrap());
        let b = l.dim_narrow(0, slice!(15, 5, -2)).unwrap();
        assert_eq!(b, Layout::new([5], [-2], 15).unwrap());
        let b = l.dim_narrow(0, slice!(15, 4, -2)).unwrap();
        assert_eq!(b, Layout::new([6], [-2], 15).unwrap());
    }

    #[test]
    fn test_expand_dims() {
        let l = Layout::<Ix3>::new([2, 3, 4], [1, 10, 100], 0).unwrap();

        let l1 = l.dim_insert(0).unwrap();
        assert_eq!(l1.shape(), &[1, 2, 3, 4]);
        let l2 = l.dim_insert(1).unwrap();
        assert_eq!(l2.shape(), &[2, 1, 3, 4]);
        let l3 = l.dim_insert(3).unwrap();
        assert_eq!(l3.shape(), &[2, 3, 4, 1]);
        // Negative positions count from the end of the output: -1 appends, -4 prepends.
        let l4 = l.dim_insert(-1).unwrap();
        assert_eq!(l4.shape(), &[2, 3, 4, 1]);
        let l5 = l.dim_insert(-4).unwrap();
        assert_eq!(l5.shape(), &[1, 2, 3, 4]);
    }
}