1use crate::array::owned::Array;
10use crate::array::view::ArrayView;
11use crate::array::view_mut::ArrayViewMut;
12use crate::dimension::{Axis, Dimension, IxDyn};
13use crate::dtype::Element;
14use crate::error::{FerrayError, FerrayResult};
15
16use super::normalize_index;
17
/// A `start:stop:step` slice along a single axis.
///
/// Every component is optional:
/// - `start = None` means index 0,
/// - `stop = None` means "through the end of the axis",
/// - `step = None` means a stride of 1.
///
/// A step of 0 is rejected when the spec is applied. With a negative step the
/// `[start, stop)` range is walked back to front (ndarray semantics), so
/// `start = 1, stop = 4, step = -1` selects indices 3, 2, 1.
#[derive(Clone, Copy, Debug)]
pub struct SliceSpec {
    /// First index of the range (`None` = 0).
    pub start: Option<isize>,
    /// One past the last index of the range (`None` = end of axis).
    pub stop: Option<isize>,
    /// Stride between selected elements (`None` = 1; must not be 0).
    pub step: Option<isize>,
}
33
34impl SliceSpec {
35 #[must_use]
37 pub const fn full() -> Self {
38 Self {
39 start: None,
40 stop: None,
41 step: None,
42 }
43 }
44
45 #[must_use]
47 pub const fn new(start: isize, stop: isize) -> Self {
48 Self {
49 start: Some(start),
50 stop: Some(stop),
51 step: None,
52 }
53 }
54
55 #[must_use]
57 pub const fn with_step(start: isize, stop: isize, step: isize) -> Self {
58 Self {
59 start: Some(start),
60 stop: Some(stop),
61 step: Some(step),
62 }
63 }
64
65 fn validate(&self) -> FerrayResult<()> {
67 if self.step == Some(0) {
68 return Err(FerrayError::invalid_value("slice step cannot be zero"));
69 }
70 Ok(())
71 }
72
73 #[allow(clippy::wrong_self_convention)]
75 fn to_ndarray_slice(&self) -> ndarray::Slice {
76 ndarray::Slice::new(self.start.unwrap_or(0), self.stop, self.step.unwrap_or(1))
77 }
78
79 #[allow(dead_code, clippy::wrong_self_convention)]
81 pub(crate) fn to_ndarray_elem(&self) -> ndarray::SliceInfoElem {
82 ndarray::SliceInfoElem::Slice {
83 start: self.start.unwrap_or(0),
84 end: self.stop,
85 step: self.step.unwrap_or(1),
86 }
87 }
88}
89
90impl<T: Element, D: Dimension> Array<T, D> {
95 pub fn index_axis(&self, axis: Axis, index: isize) -> FerrayResult<ArrayView<'_, T, IxDyn>>
104 where
105 D::NdarrayDim: ndarray::RemoveAxis,
106 {
107 let ndim = self.ndim();
108 let ax = axis.index();
109 if ax >= ndim {
110 return Err(FerrayError::axis_out_of_bounds(ax, ndim));
111 }
112 let size = self.shape()[ax];
113 let idx = normalize_index(index, size, ax)?;
114
115 let nd_axis = ndarray::Axis(ax);
116 let sub = self.inner.index_axis(nd_axis, idx);
117 let dyn_view = sub.into_dyn();
118 Ok(ArrayView::from_ndarray(dyn_view))
119 }
120
121 pub fn slice_axis(&self, axis: Axis, spec: SliceSpec) -> FerrayResult<ArrayView<'_, T, IxDyn>> {
129 let ndim = self.ndim();
130 let ax = axis.index();
131 if ax >= ndim {
132 return Err(FerrayError::axis_out_of_bounds(ax, ndim));
133 }
134 spec.validate()?;
135
136 let nd_axis = ndarray::Axis(ax);
137 let nd_slice = spec.to_ndarray_slice();
138 let sliced = self.inner.slice_axis(nd_axis, nd_slice);
139 let dyn_view = sliced.into_dyn();
140 Ok(ArrayView::from_ndarray(dyn_view))
141 }
142
143 pub fn slice_axis_mut(
148 &mut self,
149 axis: Axis,
150 spec: SliceSpec,
151 ) -> FerrayResult<ArrayViewMut<'_, T, IxDyn>> {
152 let ndim = self.ndim();
153 let ax = axis.index();
154 if ax >= ndim {
155 return Err(FerrayError::axis_out_of_bounds(ax, ndim));
156 }
157 spec.validate()?;
158
159 let nd_axis = ndarray::Axis(ax);
160 let nd_slice = spec.to_ndarray_slice();
161 let sliced = self.inner.slice_axis_mut(nd_axis, nd_slice);
162 let dyn_view = sliced.into_dyn();
163 Ok(ArrayViewMut::from_ndarray(dyn_view))
164 }
165
166 pub fn slice_multi(&self, specs: &[SliceSpec]) -> FerrayResult<ArrayView<'_, T, IxDyn>> {
175 let ndim = self.ndim();
176 if specs.len() != ndim {
177 return Err(FerrayError::invalid_value(format!(
178 "expected {} slice specs, got {}",
179 ndim,
180 specs.len()
181 )));
182 }
183
184 for spec in specs {
185 spec.validate()?;
186 }
187
188 let mut result = self.inner.view().into_dyn();
190 for (ax, spec) in specs.iter().enumerate() {
191 let nd_axis = ndarray::Axis(ax);
192 let nd_slice = spec.to_ndarray_slice();
193 result = result.slice_axis_move(nd_axis, nd_slice).into_dyn();
194 }
195 Ok(ArrayView::from_ndarray(result))
196 }
197
198 pub fn insert_axis(&self, axis: Axis) -> FerrayResult<ArrayView<'_, T, IxDyn>> {
206 let ndim = self.ndim();
207 let ax = axis.index();
208 if ax > ndim {
209 return Err(FerrayError::axis_out_of_bounds(ax, ndim + 1));
210 }
211
212 let dyn_view = self.inner.view().into_dyn();
213 let expanded = dyn_view.insert_axis(ndarray::Axis(ax));
214 Ok(ArrayView::from_ndarray(expanded))
215 }
216
217 pub fn remove_axis(&self, axis: Axis) -> FerrayResult<ArrayView<'_, T, IxDyn>> {
226 let ndim = self.ndim();
227 let ax = axis.index();
228 if ax >= ndim {
229 return Err(FerrayError::axis_out_of_bounds(ax, ndim));
230 }
231 if self.shape()[ax] != 1 {
232 return Err(FerrayError::invalid_value(format!(
233 "cannot remove axis {} with size {} (must be 1)",
234 ax,
235 self.shape()[ax]
236 )));
237 }
238
239 let dyn_view = self.inner.view().into_dyn();
241 let squeezed = dyn_view.index_axis_move(ndarray::Axis(ax), 0);
242 Ok(ArrayView::from_ndarray(squeezed))
243 }
244
245 pub fn flat_index(&self, index: isize) -> FerrayResult<&T> {
254 let size = self.size();
255 let idx = normalize_index(index, size, 0)?;
256
257 let shape = self.shape();
260 let strides = self.inner.strides();
261 let base_ptr = self.inner.as_ptr();
262 let ndim = shape.len();
263
264 let mut remaining = idx;
265 let mut offset: isize = 0;
266 for d in 0..ndim {
267 let dim_stride: usize = shape[d + 1..].iter().product::<usize>().max(1);
268 let coord = remaining / dim_stride;
269 remaining %= dim_stride;
270 offset += coord as isize * strides[d];
271 }
272
273 Ok(unsafe { &*base_ptr.offset(offset) })
277 }
278
279 pub fn get(&self, indices: &[isize]) -> FerrayResult<&T> {
287 let ndim = self.ndim();
288 if indices.len() != ndim {
289 return Err(FerrayError::invalid_value(format!(
290 "expected {} indices, got {}",
291 ndim,
292 indices.len()
293 )));
294 }
295
296 let shape = self.shape();
298 let strides = self.inner.strides();
299 let base_ptr = self.inner.as_ptr();
300
301 let mut offset: isize = 0;
302 for (ax, &idx) in indices.iter().enumerate() {
303 let pos = normalize_index(idx, shape[ax], ax)?;
304 offset += pos as isize * strides[ax];
305 }
306
307 Ok(unsafe { &*base_ptr.offset(offset) })
310 }
311
312 pub fn get_mut(&mut self, indices: &[isize]) -> FerrayResult<&mut T> {
317 let ndim = self.ndim();
318 if indices.len() != ndim {
319 return Err(FerrayError::invalid_value(format!(
320 "expected {} indices, got {}",
321 ndim,
322 indices.len()
323 )));
324 }
325
326 let shape = self.shape().to_vec();
327 let strides: Vec<isize> = self.inner.strides().to_vec();
328 let base_ptr = self.inner.as_mut_ptr();
329
330 let mut offset: isize = 0;
331 for (ax, &idx) in indices.iter().enumerate() {
332 let pos = normalize_index(idx, shape[ax], ax)?;
333 offset += pos as isize * strides[ax];
334 }
335
336 Ok(unsafe { &mut *base_ptr.offset(offset) })
339 }
340}
341
342impl<'a, T: Element, D: Dimension> ArrayView<'a, T, D> {
347 pub fn index_axis(&self, axis: Axis, index: isize) -> FerrayResult<ArrayView<'a, T, IxDyn>>
349 where
350 D::NdarrayDim: ndarray::RemoveAxis,
351 {
352 let ndim = self.ndim();
353 let ax = axis.index();
354 if ax >= ndim {
355 return Err(FerrayError::axis_out_of_bounds(ax, ndim));
356 }
357 let size = self.shape()[ax];
358 let idx = normalize_index(index, size, ax)?;
359
360 let nd_axis = ndarray::Axis(ax);
361 let sub = self.inner.clone().index_axis_move(nd_axis, idx);
363 let dyn_view = sub.into_dyn();
364 Ok(ArrayView::from_ndarray(dyn_view))
365 }
366
367 pub fn slice_axis(&self, axis: Axis, spec: SliceSpec) -> FerrayResult<ArrayView<'a, T, IxDyn>> {
369 let ndim = self.ndim();
370 let ax = axis.index();
371 if ax >= ndim {
372 return Err(FerrayError::axis_out_of_bounds(ax, ndim));
373 }
374 spec.validate()?;
375
376 let nd_axis = ndarray::Axis(ax);
377 let nd_slice = spec.to_ndarray_slice();
378 let sliced = self.inner.clone().slice_axis_move(nd_axis, nd_slice);
380 let dyn_view = sliced.into_dyn();
381 Ok(ArrayView::from_ndarray(dyn_view))
382 }
383
384 pub fn insert_axis(&self, axis: Axis) -> FerrayResult<ArrayView<'a, T, IxDyn>> {
386 let ndim = self.ndim();
387 let ax = axis.index();
388 if ax > ndim {
389 return Err(FerrayError::axis_out_of_bounds(ax, ndim + 1));
390 }
391
392 let dyn_view = self.inner.clone().into_dyn();
393 let expanded = dyn_view.insert_axis(ndarray::Axis(ax));
394 Ok(ArrayView::from_ndarray(expanded))
395 }
396
397 pub fn remove_axis(&self, axis: Axis) -> FerrayResult<ArrayView<'a, T, IxDyn>> {
399 let ndim = self.ndim();
400 let ax = axis.index();
401 if ax >= ndim {
402 return Err(FerrayError::axis_out_of_bounds(ax, ndim));
403 }
404 if self.shape()[ax] != 1 {
405 return Err(FerrayError::invalid_value(format!(
406 "cannot remove axis {} with size {} (must be 1)",
407 ax,
408 self.shape()[ax]
409 )));
410 }
411
412 let dyn_view = self.inner.clone().into_dyn();
413 let squeezed = dyn_view.index_axis_move(ndarray::Axis(ax), 0);
414 Ok(ArrayView::from_ndarray(squeezed))
415 }
416
417 pub fn get(&self, indices: &[isize]) -> FerrayResult<&'a T> {
419 let ndim = self.ndim();
420 if indices.len() != ndim {
421 return Err(FerrayError::invalid_value(format!(
422 "expected {} indices, got {}",
423 ndim,
424 indices.len()
425 )));
426 }
427
428 let shape = self.shape();
429 let strides = self.inner.strides();
430 let base_ptr = self.inner.as_ptr();
431
432 let mut offset: isize = 0;
433 for (ax, &idx) in indices.iter().enumerate() {
434 let pos = normalize_index(idx, shape[ax], ax)?;
435 offset += pos as isize * strides[ax];
436 }
437
438 Ok(unsafe { &*base_ptr.offset(offset) })
440 }
441}
442
443impl<T: Element, D: Dimension> ArrayViewMut<'_, T, D> {
448 pub fn slice_axis_mut(
450 &mut self,
451 axis: Axis,
452 spec: SliceSpec,
453 ) -> FerrayResult<ArrayViewMut<'_, T, IxDyn>> {
454 let ndim = self.ndim();
455 let ax = axis.index();
456 if ax >= ndim {
457 return Err(FerrayError::axis_out_of_bounds(ax, ndim));
458 }
459 spec.validate()?;
460
461 let nd_axis = ndarray::Axis(ax);
462 let nd_slice = spec.to_ndarray_slice();
463 let sliced = self.inner.slice_axis_mut(nd_axis, nd_slice);
464 let dyn_view = sliced.into_dyn();
465 Ok(ArrayViewMut::from_ndarray(dyn_view))
466 }
467}
468
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::{Ix1, Ix2, Ix3};

    /// `rows x cols` matrix holding 0, 1, 2, ... in row-major order.
    fn iota(rows: usize, cols: usize) -> Array<i32, Ix2> {
        let data: Vec<i32> = (0..(rows * cols) as i32).collect();
        Array::<i32, Ix2>::from_vec(Ix2::new([rows, cols]), data).unwrap()
    }

    /// One-dimensional `i32` array owning `data`.
    fn ints(data: Vec<i32>) -> Array<i32, Ix1> {
        Array::<i32, Ix1>::from_vec(Ix1::new([data.len()]), data).unwrap()
    }

    /// One-dimensional `f64` array owning `data`.
    fn floats(data: Vec<f64>) -> Array<f64, Ix1> {
        Array::<f64, Ix1>::from_vec(Ix1::new([data.len()]), data).unwrap()
    }

    // ----- normalize_index -------------------------------------------------

    #[test]
    fn normalize_positive_in_bounds() {
        assert_eq!(normalize_index(2, 5, 0).unwrap(), 2);
    }

    #[test]
    fn normalize_negative() {
        for (raw, expected) in [(-1, 4), (-5, 0)] {
            assert_eq!(normalize_index(raw, 5, 0).unwrap(), expected);
        }
    }

    #[test]
    fn normalize_out_of_bounds() {
        for raw in [5, -6] {
            assert!(normalize_index(raw, 5, 0).is_err());
        }
    }

    // ----- index_axis ------------------------------------------------------

    #[test]
    fn index_axis_row() {
        let arr = iota(3, 4);
        let row = arr.index_axis(Axis(0), 1).unwrap();
        assert_eq!(row.shape(), &[4]);
        assert_eq!(row.iter().copied().collect::<Vec<i32>>(), [4, 5, 6, 7]);
    }

    #[test]
    fn index_axis_column() {
        let arr = iota(3, 4);
        let col = arr.index_axis(Axis(1), 2).unwrap();
        assert_eq!(col.shape(), &[3]);
        assert_eq!(col.iter().copied().collect::<Vec<i32>>(), [2, 6, 10]);
    }

    #[test]
    fn index_axis_negative() {
        let arr = iota(3, 4);
        let last = arr.index_axis(Axis(0), -1).unwrap();
        assert_eq!(last.iter().copied().collect::<Vec<i32>>(), [8, 9, 10, 11]);
    }

    #[test]
    fn index_axis_out_of_bounds() {
        let arr = iota(3, 4);
        assert!(arr.index_axis(Axis(0), 3).is_err(), "index past end of axis");
        assert!(arr.index_axis(Axis(2), 0).is_err(), "axis past ndim");
    }

    #[test]
    fn index_axis_is_zero_copy() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .unwrap();
        let first_row = arr.index_axis(Axis(0), 0).unwrap();
        assert_eq!(first_row.as_ptr(), arr.as_ptr());
    }

    // ----- slice_axis ------------------------------------------------------

    #[test]
    fn slice_axis_basic() {
        let arr = ints(vec![10, 20, 30, 40, 50]);
        let sliced = arr.slice_axis(Axis(0), SliceSpec::new(1, 4)).unwrap();
        assert_eq!(sliced.shape(), &[3]);
        assert_eq!(sliced.iter().copied().collect::<Vec<i32>>(), [20, 30, 40]);
    }

    #[test]
    fn slice_axis_step() {
        let arr = ints((0..6).collect());
        let evens = arr.slice_axis(Axis(0), SliceSpec::with_step(0, 6, 2)).unwrap();
        assert_eq!(evens.shape(), &[3]);
        assert_eq!(evens.iter().copied().collect::<Vec<i32>>(), [0, 2, 4]);
    }

    #[test]
    fn slice_axis_negative_step() {
        let arr = ints((0..5).collect());
        let reversed = SliceSpec {
            step: Some(-1),
            ..SliceSpec::full()
        };
        let sliced = arr.slice_axis(Axis(0), reversed).unwrap();
        assert_eq!(sliced.iter().copied().collect::<Vec<i32>>(), [4, 3, 2, 1, 0]);
    }

    #[test]
    fn slice_axis_negative_step_partial() {
        // ndarray semantics: select [1, 4) first, then walk it backwards.
        let arr = ints((0..5).collect());
        let sliced = arr
            .slice_axis(Axis(0), SliceSpec::with_step(1, 4, -1))
            .unwrap();
        assert_eq!(sliced.iter().copied().collect::<Vec<i32>>(), [3, 2, 1]);
    }

    #[test]
    fn slice_axis_full() {
        let arr = ints(vec![1, 2, 3]);
        let sliced = arr.slice_axis(Axis(0), SliceSpec::full()).unwrap();
        assert_eq!(sliced.iter().copied().collect::<Vec<i32>>(), [1, 2, 3]);
    }

    #[test]
    fn slice_axis_2d_rows() {
        let arr = iota(4, 3);
        let middle = arr.slice_axis(Axis(0), SliceSpec::new(1, 3)).unwrap();
        assert_eq!(middle.shape(), &[2, 3]);
        assert_eq!(middle.iter().copied().collect::<Vec<i32>>(), [3, 4, 5, 6, 7, 8]);
    }

    #[test]
    fn slice_axis_is_zero_copy() {
        let arr = floats(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let sliced = arr.slice_axis(Axis(0), SliceSpec::new(1, 4)).unwrap();
        // SAFETY: the slice is non-empty and borrows `arr`, so its first-element
        // pointer is valid to read.
        let first = unsafe { *sliced.as_ptr() };
        assert_eq!(first, 2.0);
    }

    #[test]
    fn slice_axis_zero_step_error() {
        let arr = ints(vec![1, 2, 3]);
        let zero_step = SliceSpec::with_step(0, 3, 0);
        assert!(arr.slice_axis(Axis(0), zero_step).is_err());
    }

    // ----- slice_multi -----------------------------------------------------

    #[test]
    fn slice_multi_2d() {
        let arr = iota(4, 5);
        let specs = [SliceSpec::new(1, 3), SliceSpec::new(0, 4)];
        let sliced = arr.slice_multi(&specs).unwrap();
        assert_eq!(sliced.shape(), &[2, 4]);
    }

    #[test]
    fn slice_multi_wrong_count() {
        let arr = iota(2, 3);
        assert!(arr.slice_multi(&[SliceSpec::full()]).is_err());
    }

    // ----- insert_axis / remove_axis ---------------------------------------

    #[test]
    fn insert_axis_at_front() {
        let arr = floats(vec![1.0, 2.0, 3.0]);
        let expanded = arr.insert_axis(Axis(0)).unwrap();
        assert_eq!(expanded.shape(), &[1, 3]);
    }

    #[test]
    fn insert_axis_at_end() {
        let arr = floats(vec![1.0, 2.0, 3.0]);
        let expanded = arr.insert_axis(Axis(1)).unwrap();
        assert_eq!(expanded.shape(), &[3, 1]);
    }

    #[test]
    fn insert_axis_out_of_bounds() {
        let arr = floats(vec![1.0, 2.0, 3.0]);
        assert!(arr.insert_axis(Axis(3)).is_err());
    }

    #[test]
    fn remove_axis_single() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([1, 3]), vec![1.0, 2.0, 3.0]).unwrap();
        let squeezed = arr.remove_axis(Axis(0)).unwrap();
        assert_eq!(squeezed.shape(), &[3]);
        assert_eq!(squeezed.iter().copied().collect::<Vec<f64>>(), [1.0, 2.0, 3.0]);
    }

    #[test]
    fn remove_axis_not_one() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0; 6]).unwrap();
        assert!(arr.remove_axis(Axis(0)).is_err());
    }

    #[test]
    fn remove_axis_out_of_bounds() {
        let arr = floats(vec![1.0, 2.0, 3.0]);
        assert!(arr.remove_axis(Axis(1)).is_err());
    }

    // ----- flat_index ------------------------------------------------------

    #[test]
    fn flat_index_positive() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 5, 6]).unwrap();
        for (flat, expected) in [(0, 1), (3, 4), (5, 6)] {
            assert_eq!(*arr.flat_index(flat).unwrap(), expected);
        }
    }

    #[test]
    fn flat_index_negative() {
        let arr = ints(vec![10, 20, 30, 40, 50]);
        assert_eq!(*arr.flat_index(-1).unwrap(), 50);
        assert_eq!(*arr.flat_index(-5).unwrap(), 10);
    }

    #[test]
    fn flat_index_out_of_bounds() {
        let arr = ints(vec![1, 2, 3]);
        assert!(arr.flat_index(3).is_err());
        assert!(arr.flat_index(-4).is_err());
    }

    // ----- get / get_mut ---------------------------------------------------

    #[test]
    fn get_2d() {
        let arr = iota(3, 4);
        for (idx, expected) in [([0, 0], 0), ([1, 2], 6), ([2, 3], 11)] {
            assert_eq!(*arr.get(&idx).unwrap(), expected);
        }
    }

    #[test]
    fn get_negative_indices() {
        let arr = iota(3, 4);
        assert_eq!(*arr.get(&[-1, -1]).unwrap(), 11);
        assert_eq!(*arr.get(&[-3, 0]).unwrap(), 0);
    }

    #[test]
    fn get_wrong_ndim() {
        let arr = iota(2, 3);
        assert!(arr.get(&[0]).is_err());
        assert!(arr.get(&[0, 0, 0]).is_err());
    }

    #[test]
    fn get_mut_modify() {
        let mut arr =
            Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 5, 6]).unwrap();
        *arr.get_mut(&[1, 2]).unwrap() = 99;
        assert_eq!(*arr.get(&[1, 2]).unwrap(), 99);
    }

    // ----- views -----------------------------------------------------------

    #[test]
    fn view_index_axis() {
        let arr = iota(3, 4);
        let view = arr.view();
        let row = view.index_axis(Axis(0), 1).unwrap();
        assert_eq!(row.iter().copied().collect::<Vec<i32>>(), [4, 5, 6, 7]);
    }

    #[test]
    fn view_slice_axis() {
        let arr = ints(vec![10, 20, 30, 40, 50]);
        let view = arr.view();
        let sliced = view.slice_axis(Axis(0), SliceSpec::new(1, 4)).unwrap();
        assert_eq!(sliced.iter().copied().collect::<Vec<i32>>(), [20, 30, 40]);
    }

    #[test]
    fn view_insert_remove_axis() {
        let arr = floats(vec![1.0, 2.0, 3.0, 4.0]);
        let view = arr.view();
        let expanded = view.insert_axis(Axis(0)).unwrap();
        assert_eq!(expanded.shape(), &[1, 4]);
        let squeezed = expanded.remove_axis(Axis(0)).unwrap();
        assert_eq!(squeezed.shape(), &[4]);
    }

    #[test]
    fn view_get() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 5, 6]).unwrap();
        let view = arr.view();
        assert_eq!(*view.get(&[1, 2]).unwrap(), 6);
    }

    #[test]
    fn view_mut_slice_axis() {
        let mut arr = ints(vec![1, 2, 3, 4, 5]);
        {
            let mut whole = arr.view_mut();
            let mut middle = whole.slice_axis_mut(Axis(0), SliceSpec::new(1, 3)).unwrap();
            if let Some(s) = middle.as_slice_mut() {
                s.copy_from_slice(&[20, 30]);
            }
        }
        assert_eq!(arr.as_slice().unwrap(), &[1, 20, 30, 4, 5]);
    }

    // ----- higher rank -----------------------------------------------------

    #[test]
    fn index_axis_3d() {
        let arr = Array::<i32, Ix3>::from_vec(Ix3::new([2, 3, 4]), (0..24).collect()).unwrap();
        let plane = arr.index_axis(Axis(0), 1).unwrap();
        assert_eq!(plane.shape(), &[3, 4]);
        assert_eq!(*plane.get(&[0, 0]).unwrap(), 12);
    }
}