1use std::borrow::Cow;
2
3use crate::Array;
4
5impl<'a, T: Clone, const D: usize> Array<'a, T, D> {
6 pub fn transpose(mut self) -> Array<'a, T, D> {
7 self.shape.reverse();
8 self.strides.reverse();
9 self.idx_maps.reverse();
10
11 self
12 }
13
14 pub fn t(&'a self) -> Array<'a, T, D> {
15 let mut shape = self.shape.clone();
16 let mut strides = self.strides.clone();
17 let mut idx_maps = self.idx_maps.clone();
18
19 shape.reverse();
20 strides.reverse();
21 idx_maps.reverse();
22
23 Array {
24 vec: Cow::from(&*self.vec),
25 shape,
26 strides,
27 idx_maps,
28 }
29 }
30
31 pub fn flip(&'a self, axis: usize) -> Array<'a, T, D> {
32 if axis >= D {
33 panic!("Axis out of bounds")
34 }
35
36 let mut idx_maps = self.idx_maps.clone();
37
38 let idx_map = &mut idx_maps[axis];
39
40 idx_map.append_b((self.shape[axis] - 1) as isize);
41 idx_map.m *= -1;
42
43 Array {
44 vec: Cow::from(&*self.vec),
45 shape: self.shape.clone(),
46 strides: self.strides.clone(),
47 idx_maps,
48 }
49 }
50
51 pub fn swap_axes(&'a self, axis0: usize, axis1: usize) -> Array<'a, T, D> {
52 if axis0 >= D || axis1 >= D {
53 panic!("Axis out of bounds")
54 }
55
56 let mut shape = self.shape.clone();
57 let mut strides = self.strides.clone();
58 let mut idx_maps = self.idx_maps.clone();
59
60 shape.swap(axis0, axis1);
61 strides.swap(axis0, axis1);
62 idx_maps.swap(axis0, axis1);
63
64 Array {
65 vec: Cow::from(&*self.vec),
66 shape,
67 strides,
68 idx_maps,
69 }
70 }
71
72 pub fn reshape<const S: usize>(&self, shape: [usize; S]) -> Array<'a, T, S> {
73 let vec = self.flat().cloned().collect();
76
77 Array::init(vec, shape)
78 }
79
80 pub fn flatten(&self) -> Array<'a, T, 1> {
81 let vec = self.flat().cloned().collect();
82
83 Array::init(vec, [self.vec.len()])
84 }
85
86 pub fn ravel(&self) -> Array<'a, T, 1> {
87 self.reshape([self.vec.len()])
88 }
89}
90
91impl<'a, T: Clone + Default, const D: usize> Array<'a, T, D> {
92 pub fn resize<const S: usize>(mut self, shape: [usize; S]) -> Array<'a, T, S> {
93 let new_size = shape.iter().product();
94 let old_size = self.shape().iter().product();
95
96 if new_size > old_size {
97 self.vec.to_mut().resize_with(new_size, || T::default());
98 } else {
99 self.vec.to_mut().truncate(new_size);
100 }
101
102 self.reshape(shape)
103 }
104}
105
#[cfg(test)]
mod tests {
    use super::*;

    /// Reshaping 2x3 into 3x2 keeps the elements in their original order.
    #[test]
    fn reshape_array() {
        let source = Array::init(vec![1, 2, 3, 4, 5, 6], [2, 3]);
        let reshaped = source.reshape([3, 2]);

        let expected = [
            ([0, 0], 1),
            ([0, 1], 2),
            ([1, 0], 3),
            ([1, 1], 4),
            ([2, 0], 5),
            ([2, 1], 6),
        ];
        for (index, value) in expected {
            assert_eq!(reshaped[index], value);
        }
    }

    /// Transposing 2x3 yields a 3x2 array whose rows are the old columns.
    #[test]
    fn transpose() {
        let transposed = Array::init(vec![1, 2, 3, 4, 5, 6], [2, 3]).transpose();

        let expected = [
            ([0, 0], 1),
            ([0, 1], 4),
            ([1, 0], 2),
            ([1, 1], 5),
            ([2, 0], 3),
            ([2, 1], 6),
        ];
        for (index, value) in expected {
            assert_eq!(transposed[index], value);
        }
    }

    /// Reshaping after a transpose reads the elements in transposed order.
    #[test]
    fn transpose_the_reshape() {
        let transposed = Array::init(vec![1, 2, 3, 4, 5, 6], [2, 3]).transpose();
        let reshaped = transposed.reshape([2, 3]);

        let expected = [
            ([0, 0], 1),
            ([0, 1], 4),
            ([0, 2], 2),
            ([1, 0], 5),
            ([1, 1], 3),
            ([1, 2], 6),
        ];
        for (index, value) in expected {
            assert_eq!(reshaped[index], value);
        }
    }

    /// Flipping axis 0 of a 2x3 array exchanges its two rows.
    #[test]
    fn flip() {
        let source = Array::init(vec![1, 2, 3, 4, 5, 6], [2, 3]);
        let flipped = source.flip(0);

        let actual: Vec<usize> = flipped.flat().copied().collect();
        let expected: Vec<usize> = vec![4, 5, 6, 1, 2, 3];

        assert_eq!(actual, expected);
    }

    /// Swapping the axes of a 1x3 array gives a 3x1 array.
    #[test]
    fn swap_axis() {
        let row = Array::init(vec![1, 2, 3], [1, 3]);
        let column = row.swap_axes(0, 1);

        for (index, value) in [([0, 0], 1), ([1, 0], 2), ([2, 0], 3)] {
            assert_eq!(column[index], value);
        }
    }

    /// Flattening a 2x3 array yields its six elements in order.
    #[test]
    fn flatten() {
        let source = Array::init(vec![1, 2, 3, 4, 5, 6], [2, 3]);
        let flat_array = source.flatten();

        let actual: Vec<usize> = flat_array.flat().copied().collect();
        let expected: Vec<usize> = vec![1, 2, 3, 4, 5, 6];

        assert_eq!(actual, expected)
    }
}