ndarray/impl_dyn.rs
1// Copyright 2018 bluss and ndarray developers.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9//! Methods for dynamic-dimensional arrays.
10use crate::imp_prelude::*;
11
12/// # Methods for Dynamic-Dimensional Arrays
13impl<A> LayoutRef<A, IxDyn>
14{
15 /// Insert new array axis of length 1 at `axis`, modifying the shape and
16 /// strides in-place.
17 ///
18 /// **Panics** if the axis is out of bounds.
19 ///
20 /// ```
21 /// use ndarray::{Axis, arr2, arr3};
22 ///
23 /// let mut a = arr2(&[[1, 2, 3], [4, 5, 6]]).into_dyn();
24 /// assert_eq!(a.shape(), &[2, 3]);
25 ///
26 /// a.insert_axis_inplace(Axis(1));
27 /// assert_eq!(a, arr3(&[[[1, 2, 3]], [[4, 5, 6]]]).into_dyn());
28 /// assert_eq!(a.shape(), &[2, 1, 3]);
29 /// ```
30 #[track_caller]
31 pub fn insert_axis_inplace(&mut self, axis: Axis)
32 {
33 assert!(axis.index() <= self.ndim());
34 self.0.dim = self._dim().insert_axis(axis);
35 self.0.strides = self._strides().insert_axis(axis);
36 }
37
38 /// Collapses the array to `index` along the axis and removes the axis,
39 /// modifying the shape and strides in-place.
40 ///
41 /// **Panics** if `axis` or `index` is out of bounds.
42 ///
43 /// ```
44 /// use ndarray::{Axis, arr1, arr2};
45 ///
46 /// let mut a = arr2(&[[1, 2, 3], [4, 5, 6]]).into_dyn();
47 /// assert_eq!(a.shape(), &[2, 3]);
48 ///
49 /// a.index_axis_inplace(Axis(1), 1);
50 /// assert_eq!(a, arr1(&[2, 5]).into_dyn());
51 /// assert_eq!(a.shape(), &[2]);
52 /// ```
53 #[track_caller]
54 pub fn index_axis_inplace(&mut self, axis: Axis, index: usize)
55 {
56 self.collapse_axis(axis, index);
57 self.0.dim = self._dim().remove_axis(axis);
58 self.0.strides = self._strides().remove_axis(axis);
59 }
60}
61
62impl<S: RawData> ArrayBase<S, IxDyn>
63{
64 /// Insert new array axis of length 1 at `axis`, modifying the shape and
65 /// strides in-place.
66 ///
67 /// **Panics** if the axis is out of bounds.
68 ///
69 /// ```
70 /// use ndarray::{Axis, arr2, arr3};
71 ///
72 /// let mut a = arr2(&[[1, 2, 3], [4, 5, 6]]).into_dyn();
73 /// assert_eq!(a.shape(), &[2, 3]);
74 ///
75 /// a.insert_axis_inplace(Axis(1));
76 /// assert_eq!(a, arr3(&[[[1, 2, 3]], [[4, 5, 6]]]).into_dyn());
77 /// assert_eq!(a.shape(), &[2, 1, 3]);
78 /// ```
79 #[track_caller]
80 pub fn insert_axis_inplace(&mut self, axis: Axis)
81 {
82 self.as_mut().insert_axis_inplace(axis)
83 }
84
85 /// Collapses the array to `index` along the axis and removes the axis,
86 /// modifying the shape and strides in-place.
87 ///
88 /// **Panics** if `axis` or `index` is out of bounds.
89 ///
90 /// ```
91 /// use ndarray::{Axis, arr1, arr2};
92 ///
93 /// let mut a = arr2(&[[1, 2, 3], [4, 5, 6]]).into_dyn();
94 /// assert_eq!(a.shape(), &[2, 3]);
95 ///
96 /// a.index_axis_inplace(Axis(1), 1);
97 /// assert_eq!(a, arr1(&[2, 5]).into_dyn());
98 /// assert_eq!(a.shape(), &[2]);
99 /// ```
100 #[track_caller]
101 pub fn index_axis_inplace(&mut self, axis: Axis, index: usize)
102 {
103 self.as_mut().index_axis_inplace(axis, index)
104 }
105}
106
107impl<A, S> ArrayBase<S, IxDyn>
108where S: Data<Elem = A>
109{
110 /// Remove axes of length 1 and return the modified array.
111 ///
112 /// If the array has more the one dimension, the result array will always
113 /// have at least one dimension, even if it has a length of 1.
114 ///
115 /// ```
116 /// use ndarray::{arr1, arr2, arr3};
117 ///
118 /// let a = arr3(&[[[1, 2, 3]], [[4, 5, 6]]]).into_dyn();
119 /// assert_eq!(a.shape(), &[2, 1, 3]);
120 /// let b = a.squeeze();
121 /// assert_eq!(b, arr2(&[[1, 2, 3], [4, 5, 6]]).into_dyn());
122 /// assert_eq!(b.shape(), &[2, 3]);
123 ///
124 /// let c = arr2(&[[1]]).into_dyn();
125 /// assert_eq!(c.shape(), &[1, 1]);
126 /// let d = c.squeeze();
127 /// assert_eq!(d, arr1(&[1]).into_dyn());
128 /// assert_eq!(d.shape(), &[1]);
129 /// ```
130 #[track_caller]
131 pub fn squeeze(self) -> Self
132 {
133 let mut out = self;
134 for axis in (0..out.shape().len()).rev() {
135 if out.shape()[axis] == 1 && out.shape().len() > 1 {
136 out = out.remove_axis(Axis(axis));
137 }
138 }
139 out
140 }
141}
142
#[cfg(test)]
mod tests
{
    use crate::{arr1, arr2, arr3};

    #[test]
    fn test_squeeze()
    {
        // A length-1 middle axis is removed.
        let three_d = arr3(&[[[1, 2, 3]], [[4, 5, 6]]]).into_dyn();
        assert_eq!(three_d.shape(), &[2, 1, 3]);

        let squeezed = three_d.squeeze();
        assert_eq!(squeezed.shape(), &[2, 3]);
        assert_eq!(squeezed, arr2(&[[1, 2, 3], [4, 5, 6]]).into_dyn());

        // Every axis has length 1: exactly one axis must survive.
        let all_ones = arr2(&[[1]]).into_dyn();
        assert_eq!(all_ones.shape(), &[1, 1]);

        let squeezed = all_ones.squeeze();
        assert_eq!(squeezed.shape(), &[1]);
        assert_eq!(squeezed, arr1(&[1]).into_dyn());
    }
}