1use crate::array::owned::Array;
10use crate::array::view::ArrayView;
11use crate::array::view_mut::ArrayViewMut;
12use crate::dimension::{Axis, Dimension, IxDyn};
13use crate::dtype::Element;
14use crate::error::{FerrayError, FerrayResult};
15
16use super::normalize_index;
17
/// A NumPy-style `start:stop:step` slice along a single axis.
///
/// Any field left as `None` falls back to a default when the spec is converted for
/// ndarray: `start` becomes 0, `stop` becomes the end of the axis, and `step`
/// becomes 1. A zero step is rejected by the slicing methods.
///
/// `SliceSpec::default()` is the all-`None` spec, i.e. the whole axis.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct SliceSpec {
    /// First index to include; `None` means the start of the axis.
    pub start: Option<isize>,
    /// Index to stop before; `None` means the end of the axis.
    pub stop: Option<isize>,
    /// Distance between selected elements; `None` means 1. Must not be zero.
    pub step: Option<isize>,
}
33
34impl SliceSpec {
35 #[must_use]
37 pub const fn full() -> Self {
38 Self {
39 start: None,
40 stop: None,
41 step: None,
42 }
43 }
44
45 #[must_use]
47 pub const fn new(start: isize, stop: isize) -> Self {
48 Self {
49 start: Some(start),
50 stop: Some(stop),
51 step: None,
52 }
53 }
54
55 #[must_use]
57 pub const fn with_step(start: isize, stop: isize, step: isize) -> Self {
58 Self {
59 start: Some(start),
60 stop: Some(stop),
61 step: Some(step),
62 }
63 }
64
65 fn validate(&self) -> FerrayResult<()> {
67 if self.step == Some(0) {
68 return Err(FerrayError::invalid_value("slice step cannot be zero"));
69 }
70 Ok(())
71 }
72
73 #[allow(clippy::wrong_self_convention)]
75 fn to_ndarray_slice(&self) -> ndarray::Slice {
76 ndarray::Slice::new(self.start.unwrap_or(0), self.stop, self.step.unwrap_or(1))
77 }
78
79 #[allow(dead_code, clippy::wrong_self_convention)]
81 pub(crate) fn to_ndarray_elem(&self) -> ndarray::SliceInfoElem {
82 ndarray::SliceInfoElem::Slice {
83 start: self.start.unwrap_or(0),
84 end: self.stop,
85 step: self.step.unwrap_or(1),
86 }
87 }
88}
89
90impl<T: Element, D: Dimension> Array<T, D> {
95 pub fn index_axis(&self, axis: Axis, index: isize) -> FerrayResult<ArrayView<'_, T, D::Smaller>>
104 where
105 D::NdarrayDim:
106 ndarray::RemoveAxis<Smaller = <D::Smaller as crate::dimension::Dimension>::NdarrayDim>,
107 {
108 let ndim = self.ndim();
111 let ax = axis.index();
112 if ax >= ndim {
113 return Err(FerrayError::axis_out_of_bounds(ax, ndim));
114 }
115 let size = self.shape()[ax];
116 let idx = normalize_index(index, size, ax)?;
117
118 let nd_axis = ndarray::Axis(ax);
119 let sub = self.inner.index_axis(nd_axis, idx);
120 Ok(ArrayView::from_ndarray(sub))
121 }
122
123 pub fn slice_axis(&self, axis: Axis, spec: SliceSpec) -> FerrayResult<ArrayView<'_, T, IxDyn>> {
131 let ndim = self.ndim();
132 let ax = axis.index();
133 if ax >= ndim {
134 return Err(FerrayError::axis_out_of_bounds(ax, ndim));
135 }
136 spec.validate()?;
137
138 let nd_axis = ndarray::Axis(ax);
139 let nd_slice = spec.to_ndarray_slice();
140 let sliced = self.inner.slice_axis(nd_axis, nd_slice);
141 let dyn_view = sliced.into_dyn();
142 Ok(ArrayView::from_ndarray(dyn_view))
143 }
144
145 pub fn slice_axis_mut(
150 &mut self,
151 axis: Axis,
152 spec: SliceSpec,
153 ) -> FerrayResult<ArrayViewMut<'_, T, IxDyn>> {
154 let ndim = self.ndim();
155 let ax = axis.index();
156 if ax >= ndim {
157 return Err(FerrayError::axis_out_of_bounds(ax, ndim));
158 }
159 spec.validate()?;
160
161 let nd_axis = ndarray::Axis(ax);
162 let nd_slice = spec.to_ndarray_slice();
163 let sliced = self.inner.slice_axis_mut(nd_axis, nd_slice);
164 let dyn_view = sliced.into_dyn();
165 Ok(ArrayViewMut::from_ndarray(dyn_view))
166 }
167
168 pub fn slice_multi(&self, specs: &[SliceSpec]) -> FerrayResult<ArrayView<'_, T, IxDyn>> {
177 let ndim = self.ndim();
178 if specs.len() != ndim {
179 return Err(FerrayError::invalid_value(format!(
180 "expected {} slice specs, got {}",
181 ndim,
182 specs.len()
183 )));
184 }
185
186 for spec in specs {
187 spec.validate()?;
188 }
189
190 let mut result = self.inner.view().into_dyn();
192 for (ax, spec) in specs.iter().enumerate() {
193 let nd_axis = ndarray::Axis(ax);
194 let nd_slice = spec.to_ndarray_slice();
195 result = result.slice_axis_move(nd_axis, nd_slice).into_dyn();
196 }
197 Ok(ArrayView::from_ndarray(result))
198 }
199
200 pub fn insert_axis(&self, axis: Axis) -> FerrayResult<ArrayView<'_, T, D::Larger>>
209 where
210 D::NdarrayDim:
211 ndarray::Dimension<Larger = <D::Larger as crate::dimension::Dimension>::NdarrayDim>,
212 {
213 let ndim = self.ndim();
214 let ax = axis.index();
215 if ax > ndim {
216 return Err(FerrayError::axis_out_of_bounds(ax, ndim + 1));
217 }
218 let expanded = self.inner.view().insert_axis(ndarray::Axis(ax));
219 Ok(ArrayView::from_ndarray(expanded))
220 }
221
222 pub fn remove_axis(&self, axis: Axis) -> FerrayResult<ArrayView<'_, T, D::Smaller>>
231 where
232 D::NdarrayDim:
233 ndarray::RemoveAxis<Smaller = <D::Smaller as crate::dimension::Dimension>::NdarrayDim>,
234 {
235 let ndim = self.ndim();
236 let ax = axis.index();
237 if ax >= ndim {
238 return Err(FerrayError::axis_out_of_bounds(ax, ndim));
239 }
240 if self.shape()[ax] != 1 {
241 return Err(FerrayError::invalid_value(format!(
242 "cannot remove axis {} with size {} (must be 1)",
243 ax,
244 self.shape()[ax]
245 )));
246 }
247 let squeezed = self.inner.view().index_axis_move(ndarray::Axis(ax), 0);
248 Ok(ArrayView::from_ndarray(squeezed))
249 }
250
251 pub fn flat_index(&self, index: isize) -> FerrayResult<&T> {
260 let size = self.size();
261 let idx = normalize_index(index, size, 0)?;
262
263 let shape = self.shape();
266 let strides = self.inner.strides();
267 let base_ptr = self.inner.as_ptr();
268 let ndim = shape.len();
269
270 let mut remaining = idx;
271 let mut offset: isize = 0;
272 for d in 0..ndim {
273 let dim_stride: usize = shape[d + 1..].iter().product::<usize>().max(1);
274 let coord = remaining / dim_stride;
275 remaining %= dim_stride;
276 offset += coord as isize * strides[d];
277 }
278
279 Ok(unsafe { &*base_ptr.offset(offset) })
283 }
284
285 pub fn get(&self, indices: &[isize]) -> FerrayResult<&T> {
293 let ndim = self.ndim();
294 if indices.len() != ndim {
295 return Err(FerrayError::invalid_value(format!(
296 "expected {} indices, got {}",
297 ndim,
298 indices.len()
299 )));
300 }
301
302 let shape = self.shape();
304 let strides = self.inner.strides();
305 let base_ptr = self.inner.as_ptr();
306
307 let mut offset: isize = 0;
308 for (ax, &idx) in indices.iter().enumerate() {
309 let pos = normalize_index(idx, shape[ax], ax)?;
310 offset += pos as isize * strides[ax];
311 }
312
313 Ok(unsafe { &*base_ptr.offset(offset) })
316 }
317
318 pub fn get_mut(&mut self, indices: &[isize]) -> FerrayResult<&mut T> {
323 let ndim = self.ndim();
324 if indices.len() != ndim {
325 return Err(FerrayError::invalid_value(format!(
326 "expected {} indices, got {}",
327 ndim,
328 indices.len()
329 )));
330 }
331
332 let shape = self.shape().to_vec();
333 let strides: Vec<isize> = self.inner.strides().to_vec();
334 let base_ptr = self.inner.as_mut_ptr();
335
336 let mut offset: isize = 0;
337 for (ax, &idx) in indices.iter().enumerate() {
338 let pos = normalize_index(idx, shape[ax], ax)?;
339 offset += pos as isize * strides[ax];
340 }
341
342 Ok(unsafe { &mut *base_ptr.offset(offset) })
345 }
346}
347
348impl<'a, T: Element, D: Dimension> ArrayView<'a, T, D> {
353 pub fn index_axis(&self, axis: Axis, index: isize) -> FerrayResult<ArrayView<'a, T, D::Smaller>>
356 where
357 D::NdarrayDim:
358 ndarray::RemoveAxis<Smaller = <D::Smaller as crate::dimension::Dimension>::NdarrayDim>,
359 {
360 let ndim = self.ndim();
361 let ax = axis.index();
362 if ax >= ndim {
363 return Err(FerrayError::axis_out_of_bounds(ax, ndim));
364 }
365 let size = self.shape()[ax];
366 let idx = normalize_index(index, size, ax)?;
367
368 let nd_axis = ndarray::Axis(ax);
369 let sub = self.inner.clone().index_axis_move(nd_axis, idx);
371 Ok(ArrayView::from_ndarray(sub))
372 }
373
374 pub fn slice_axis(&self, axis: Axis, spec: SliceSpec) -> FerrayResult<ArrayView<'a, T, IxDyn>> {
376 let ndim = self.ndim();
377 let ax = axis.index();
378 if ax >= ndim {
379 return Err(FerrayError::axis_out_of_bounds(ax, ndim));
380 }
381 spec.validate()?;
382
383 let nd_axis = ndarray::Axis(ax);
384 let nd_slice = spec.to_ndarray_slice();
385 let sliced = self.inner.clone().slice_axis_move(nd_axis, nd_slice);
387 let dyn_view = sliced.into_dyn();
388 Ok(ArrayView::from_ndarray(dyn_view))
389 }
390
391 pub fn insert_axis(&self, axis: Axis) -> FerrayResult<ArrayView<'a, T, D::Larger>>
394 where
395 D::NdarrayDim:
396 ndarray::Dimension<Larger = <D::Larger as crate::dimension::Dimension>::NdarrayDim>,
397 {
398 let ndim = self.ndim();
399 let ax = axis.index();
400 if ax > ndim {
401 return Err(FerrayError::axis_out_of_bounds(ax, ndim + 1));
402 }
403 let expanded = self.inner.clone().insert_axis(ndarray::Axis(ax));
404 Ok(ArrayView::from_ndarray(expanded))
405 }
406
407 pub fn remove_axis(&self, axis: Axis) -> FerrayResult<ArrayView<'a, T, D::Smaller>>
410 where
411 D::NdarrayDim:
412 ndarray::RemoveAxis<Smaller = <D::Smaller as crate::dimension::Dimension>::NdarrayDim>,
413 {
414 let ndim = self.ndim();
415 let ax = axis.index();
416 if ax >= ndim {
417 return Err(FerrayError::axis_out_of_bounds(ax, ndim));
418 }
419 if self.shape()[ax] != 1 {
420 return Err(FerrayError::invalid_value(format!(
421 "cannot remove axis {} with size {} (must be 1)",
422 ax,
423 self.shape()[ax]
424 )));
425 }
426
427 let squeezed = self.inner.clone().index_axis_move(ndarray::Axis(ax), 0);
428 Ok(ArrayView::from_ndarray(squeezed))
429 }
430
431 pub fn get(&self, indices: &[isize]) -> FerrayResult<&'a T> {
433 let ndim = self.ndim();
434 if indices.len() != ndim {
435 return Err(FerrayError::invalid_value(format!(
436 "expected {} indices, got {}",
437 ndim,
438 indices.len()
439 )));
440 }
441
442 let shape = self.shape();
443 let strides = self.inner.strides();
444 let base_ptr = self.inner.as_ptr();
445
446 let mut offset: isize = 0;
447 for (ax, &idx) in indices.iter().enumerate() {
448 let pos = normalize_index(idx, shape[ax], ax)?;
449 offset += pos as isize * strides[ax];
450 }
451
452 Ok(unsafe { &*base_ptr.offset(offset) })
454 }
455}
456
457impl<T: Element, D: Dimension> ArrayViewMut<'_, T, D> {
462 pub fn slice_axis_mut(
464 &mut self,
465 axis: Axis,
466 spec: SliceSpec,
467 ) -> FerrayResult<ArrayViewMut<'_, T, IxDyn>> {
468 let ndim = self.ndim();
469 let ax = axis.index();
470 if ax >= ndim {
471 return Err(FerrayError::axis_out_of_bounds(ax, ndim));
472 }
473 spec.validate()?;
474
475 let nd_axis = ndarray::Axis(ax);
476 let nd_slice = spec.to_ndarray_slice();
477 let sliced = self.inner.slice_axis_mut(nd_axis, nd_slice);
478 let dyn_view = sliced.into_dyn();
479 Ok(ArrayViewMut::from_ndarray(dyn_view))
480 }
481}
482
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::{Ix0, Ix1, Ix2, Ix3};

    /// `rows x cols` i32 matrix holding `0, 1, 2, ...` in row-major order.
    fn iota_2d(rows: usize, cols: usize) -> Array<i32, Ix2> {
        let values: Vec<i32> = (0_i32..).take(rows * cols).collect();
        Array::<i32, Ix2>::from_vec(Ix2::new([rows, cols]), values).unwrap()
    }

    /// 2x3x4 i32 cube holding `0..24` in row-major order.
    fn iota_cube() -> Array<i32, Ix3> {
        Array::<i32, Ix3>::from_vec(Ix3::new([2, 3, 4]), (0..24).collect()).unwrap()
    }

    /// 1-D i32 array with the given contents.
    fn ints(values: &[i32]) -> Array<i32, Ix1> {
        Array::<i32, Ix1>::from_vec(Ix1::new([values.len()]), values.to_vec()).unwrap()
    }

    /// 1-D f64 array with the given contents.
    fn floats(values: &[f64]) -> Array<f64, Ix1> {
        Array::<f64, Ix1>::from_vec(Ix1::new([values.len()]), values.to_vec()).unwrap()
    }

    // ----- normalize_index -----

    #[test]
    fn normalize_positive_in_bounds() {
        let resolved = normalize_index(2, 5, 0).unwrap();
        assert_eq!(resolved, 2);
    }

    #[test]
    fn normalize_negative() {
        for (raw, expected) in [(-1, 4), (-5, 0)] {
            assert_eq!(normalize_index(raw, 5, 0).unwrap(), expected);
        }
    }

    #[test]
    fn normalize_out_of_bounds() {
        for raw in [5, -6] {
            assert!(normalize_index(raw, 5, 0).is_err());
        }
    }

    // ----- index_axis -----

    #[test]
    fn index_axis_row() {
        let grid = iota_2d(3, 4);
        let second_row = grid.index_axis(Axis(0), 1).unwrap();
        assert_eq!(second_row.shape(), &[4]);
        let got: Vec<i32> = second_row.iter().copied().collect();
        assert_eq!(got, vec![4, 5, 6, 7]);
    }

    #[test]
    fn index_axis_column() {
        let grid = iota_2d(3, 4);
        let third_col = grid.index_axis(Axis(1), 2).unwrap();
        assert_eq!(third_col.shape(), &[3]);
        let got: Vec<i32> = third_col.iter().copied().collect();
        assert_eq!(got, vec![2, 6, 10]);
    }

    #[test]
    fn index_axis_negative() {
        let grid = iota_2d(3, 4);
        let last_row = grid.index_axis(Axis(0), -1).unwrap();
        let got: Vec<i32> = last_row.iter().copied().collect();
        assert_eq!(got, vec![8, 9, 10, 11]);
    }

    #[test]
    fn index_axis_out_of_bounds() {
        let grid = iota_2d(3, 4);
        assert!(grid.index_axis(Axis(0), 3).is_err());
        assert!(grid.index_axis(Axis(2), 0).is_err());
    }

    #[test]
    fn index_axis_is_zero_copy() {
        let grid =
            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
                .unwrap();
        let first_row = grid.index_axis(Axis(0), 0).unwrap();
        assert_eq!(first_row.as_ptr(), grid.as_ptr());
    }

    // ----- slice_axis -----

    #[test]
    fn slice_axis_basic() {
        let v = ints(&[10, 20, 30, 40, 50]);
        let middle = v.slice_axis(Axis(0), SliceSpec::new(1, 4)).unwrap();
        assert_eq!(middle.shape(), &[3]);
        let got: Vec<i32> = middle.iter().copied().collect();
        assert_eq!(got, vec![20, 30, 40]);
    }

    #[test]
    fn slice_axis_step() {
        let v = ints(&[0, 1, 2, 3, 4, 5]);
        let evens = v.slice_axis(Axis(0), SliceSpec::with_step(0, 6, 2)).unwrap();
        assert_eq!(evens.shape(), &[3]);
        let got: Vec<i32> = evens.iter().copied().collect();
        assert_eq!(got, vec![0, 2, 4]);
    }

    #[test]
    fn slice_axis_negative_step() {
        let v = ints(&[0, 1, 2, 3, 4]);
        let reverse_all = SliceSpec {
            start: None,
            stop: None,
            step: Some(-1),
        };
        let reversed = v.slice_axis(Axis(0), reverse_all).unwrap();
        let got: Vec<i32> = reversed.iter().copied().collect();
        assert_eq!(got, vec![4, 3, 2, 1, 0]);
    }

    #[test]
    fn slice_axis_negative_step_partial() {
        // ndarray semantics: take 1..4 first, then walk it backwards.
        let v = ints(&[0, 1, 2, 3, 4]);
        let reversed = v.slice_axis(Axis(0), SliceSpec::with_step(1, 4, -1)).unwrap();
        let got: Vec<i32> = reversed.iter().copied().collect();
        assert_eq!(got, vec![3, 2, 1]);
    }

    #[test]
    fn slice_axis_full() {
        let v = ints(&[1, 2, 3]);
        let everything = v.slice_axis(Axis(0), SliceSpec::full()).unwrap();
        let got: Vec<i32> = everything.iter().copied().collect();
        assert_eq!(got, vec![1, 2, 3]);
    }

    #[test]
    fn slice_axis_2d_rows() {
        let grid = iota_2d(4, 3);
        let band = grid.slice_axis(Axis(0), SliceSpec::new(1, 3)).unwrap();
        assert_eq!(band.shape(), &[2, 3]);
        let got: Vec<i32> = band.iter().copied().collect();
        assert_eq!(got, vec![3, 4, 5, 6, 7, 8]);
    }

    #[test]
    fn slice_axis_is_zero_copy() {
        let v = floats(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        let middle = v.slice_axis(Axis(0), SliceSpec::new(1, 4)).unwrap();
        // SAFETY: `middle` has three elements, so its first-element pointer is valid.
        let head = unsafe { *middle.as_ptr() };
        assert_eq!(head, 2.0);
    }

    #[test]
    fn slice_axis_zero_step_error() {
        let v = ints(&[1, 2, 3]);
        let outcome = v.slice_axis(Axis(0), SliceSpec::with_step(0, 3, 0));
        assert!(outcome.is_err());
    }

    // ----- slice_multi -----

    #[test]
    fn slice_multi_2d() {
        let grid = iota_2d(4, 5);
        let block = grid
            .slice_multi(&[SliceSpec::new(1, 3), SliceSpec::new(0, 4)])
            .unwrap();
        assert_eq!(block.shape(), &[2, 4]);
    }

    #[test]
    fn slice_multi_wrong_count() {
        let grid = iota_2d(2, 3);
        assert!(grid.slice_multi(&[SliceSpec::full()]).is_err());
    }

    // ----- insert_axis / remove_axis -----

    #[test]
    fn insert_axis_at_front() {
        let v = floats(&[1.0, 2.0, 3.0]);
        let lifted = v.insert_axis(Axis(0)).unwrap();
        assert_eq!(lifted.shape(), &[1, 3]);
    }

    #[test]
    fn insert_axis_at_end() {
        let v = floats(&[1.0, 2.0, 3.0]);
        let lifted = v.insert_axis(Axis(1)).unwrap();
        assert_eq!(lifted.shape(), &[3, 1]);
    }

    #[test]
    fn insert_axis_out_of_bounds() {
        let v = floats(&[1.0, 2.0, 3.0]);
        assert!(v.insert_axis(Axis(3)).is_err());
    }

    #[test]
    fn remove_axis_single() {
        let row = Array::<f64, Ix2>::from_vec(Ix2::new([1, 3]), vec![1.0, 2.0, 3.0]).unwrap();
        let flat = row.remove_axis(Axis(0)).unwrap();
        assert_eq!(flat.shape(), &[3]);
        let got: Vec<f64> = flat.iter().copied().collect();
        assert_eq!(got, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn remove_axis_not_one() {
        let grid = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0; 6]).unwrap();
        assert!(grid.remove_axis(Axis(0)).is_err());
    }

    #[test]
    fn remove_axis_out_of_bounds() {
        let v = floats(&[1.0, 2.0, 3.0]);
        assert!(v.remove_axis(Axis(1)).is_err());
    }

    // ----- flat_index -----

    #[test]
    fn flat_index_positive() {
        let grid = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 5, 6]).unwrap();
        for (flat, expected) in [(0, 1), (3, 4), (5, 6)] {
            assert_eq!(*grid.flat_index(flat).unwrap(), expected);
        }
    }

    #[test]
    fn flat_index_negative() {
        let v = ints(&[10, 20, 30, 40, 50]);
        assert_eq!(*v.flat_index(-1).unwrap(), 50);
        assert_eq!(*v.flat_index(-5).unwrap(), 10);
    }

    #[test]
    fn flat_index_out_of_bounds() {
        let v = ints(&[1, 2, 3]);
        for flat in [3, -4] {
            assert!(v.flat_index(flat).is_err());
        }
    }

    // ----- get / get_mut -----

    #[test]
    fn get_2d() {
        let grid = iota_2d(3, 4);
        for (position, expected) in [([0, 0], 0), ([1, 2], 6), ([2, 3], 11)] {
            assert_eq!(*grid.get(&position).unwrap(), expected);
        }
    }

    #[test]
    fn get_negative_indices() {
        let grid = iota_2d(3, 4);
        assert_eq!(*grid.get(&[-1, -1]).unwrap(), 11);
        assert_eq!(*grid.get(&[-3, 0]).unwrap(), 0);
    }

    #[test]
    fn get_wrong_ndim() {
        let grid = iota_2d(2, 3);
        assert!(grid.get(&[0]).is_err());
        assert!(grid.get(&[0, 0, 0]).is_err());
    }

    #[test]
    fn get_mut_modify() {
        let mut grid =
            Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 5, 6]).unwrap();
        let cell = grid.get_mut(&[1, 2]).unwrap();
        *cell = 99;
        assert_eq!(*grid.get(&[1, 2]).unwrap(), 99);
    }

    // ----- view methods -----

    #[test]
    fn view_index_axis() {
        let grid = iota_2d(3, 4);
        let view = grid.view();
        let second_row = view.index_axis(Axis(0), 1).unwrap();
        let got: Vec<i32> = second_row.iter().copied().collect();
        assert_eq!(got, vec![4, 5, 6, 7]);
    }

    #[test]
    fn view_slice_axis() {
        let v = ints(&[10, 20, 30, 40, 50]);
        let view = v.view();
        let middle = view.slice_axis(Axis(0), SliceSpec::new(1, 4)).unwrap();
        let got: Vec<i32> = middle.iter().copied().collect();
        assert_eq!(got, vec![20, 30, 40]);
    }

    #[test]
    fn view_insert_remove_axis() {
        let v = floats(&[1.0, 2.0, 3.0, 4.0]);
        let view = v.view();
        let lifted = view.insert_axis(Axis(0)).unwrap();
        assert_eq!(lifted.shape(), &[1, 4]);
        let lowered = lifted.remove_axis(Axis(0)).unwrap();
        assert_eq!(lowered.shape(), &[4]);
    }

    #[test]
    fn view_get() {
        let grid = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 5, 6]).unwrap();
        let view = grid.view();
        assert_eq!(*view.get(&[1, 2]).unwrap(), 6);
    }

    #[test]
    fn view_mut_slice_axis() {
        let mut v = ints(&[1, 2, 3, 4, 5]);
        {
            let mut whole = v.view_mut();
            let mut window = whole.slice_axis_mut(Axis(0), SliceSpec::new(1, 3)).unwrap();
            if let Some(cells) = window.as_slice_mut() {
                cells[0] = 20;
                cells[1] = 30;
            }
        }
        assert_eq!(v.as_slice().unwrap(), &[1, 20, 30, 4, 5]);
    }

    // ----- rank-preserving return types -----

    #[test]
    fn index_axis_3d() {
        let cube = iota_cube();
        let plane = cube.index_axis(Axis(0), 1).unwrap();
        assert_eq!(plane.shape(), &[3, 4]);
        assert_eq!(*plane.get(&[0, 0]).unwrap(), 12);
    }

    #[test]
    fn index_axis_2d_returns_ix1() {
        let grid = Array::<f64, Ix2>::from_vec(
            Ix2::new([3, 4]),
            (0..12_i32).map(f64::from).collect(),
        )
        .unwrap();
        let row: ArrayView<'_, f64, Ix1> = grid.index_axis(Axis(0), 1).unwrap();
        assert_eq!(row.shape(), &[4]);
    }

    #[test]
    fn index_axis_3d_returns_ix2() {
        let cube = iota_cube();
        let plane: ArrayView<'_, i32, Ix2> = cube.index_axis(Axis(0), 0).unwrap();
        assert_eq!(plane.shape(), &[3, 4]);
    }

    #[test]
    fn insert_axis_ix2_returns_ix3() {
        let grid = Array::<f64, Ix2>::from_vec(
            Ix2::new([2, 3]),
            (0..6_i32).map(f64::from).collect(),
        )
        .unwrap();
        let lifted: ArrayView<'_, f64, Ix3> = grid.insert_axis(Axis(0)).unwrap();
        assert_eq!(lifted.shape(), &[1, 2, 3]);
    }

    #[test]
    fn remove_axis_ix3_returns_ix2() {
        let slab = Array::<f64, Ix3>::from_vec(
            Ix3::new([1, 3, 4]),
            (0..12_i32).map(f64::from).collect(),
        )
        .unwrap();
        let lowered: ArrayView<'_, f64, Ix2> = slab.remove_axis(Axis(0)).unwrap();
        assert_eq!(lowered.shape(), &[3, 4]);
    }

    #[test]
    fn index_axis_chains_preserve_rank_at_each_step() {
        let cube = iota_cube();
        let plane: ArrayView<'_, i32, Ix2> = cube.index_axis(Axis(0), 1).unwrap();
        let row: ArrayView<'_, i32, Ix1> = plane.index_axis(Axis(0), 1).unwrap();
        let scalar: ArrayView<'_, i32, Ix0> = row.index_axis(Axis(0), 2).unwrap();
        assert_eq!(scalar.shape(), &[] as &[usize]);
    }
}
889}