nd_array/ndarray/array/
transformation.rs

1use std::borrow::Cow;
2
3use crate::Array;
4
5impl<'a, T: Clone, const D: usize> Array<'a, T, D> {
6    pub fn transpose(mut self) -> Array<'a, T, D> {
7        self.shape.reverse();
8        self.strides.reverse();
9        self.idx_maps.reverse();
10
11        self
12    }
13
14    pub fn t(&'a self) -> Array<'a, T, D> {
15        let mut shape = self.shape.clone();
16        let mut strides = self.strides.clone();
17        let mut idx_maps = self.idx_maps.clone();
18
19        shape.reverse();
20        strides.reverse();
21        idx_maps.reverse();
22
23        Array {
24            vec: Cow::from(&*self.vec),
25            shape,
26            strides,
27            idx_maps,
28        }
29    }
30
31    pub fn flip(&'a self, axis: usize) -> Array<'a, T, D> {
32        if axis >= D {
33            panic!("Axis out of bounds")
34        }
35
36        let mut idx_maps = self.idx_maps.clone();
37
38        let idx_map = &mut idx_maps[axis];
39
40        idx_map.append_b((self.shape[axis] - 1) as isize);
41        idx_map.m *= -1;
42
43        Array {
44            vec: Cow::from(&*self.vec),
45            shape: self.shape.clone(),
46            strides: self.strides.clone(),
47            idx_maps,
48        }
49    }
50
51    pub fn swap_axes(&'a self, axis0: usize, axis1: usize) -> Array<'a, T, D> {
52        if axis0 >= D || axis1 >= D {
53            panic!("Axis out of bounds")
54        }
55
56        let mut shape = self.shape.clone();
57        let mut strides = self.strides.clone();
58        let mut idx_maps = self.idx_maps.clone();
59
60        shape.swap(axis0, axis1);
61        strides.swap(axis0, axis1);
62        idx_maps.swap(axis0, axis1);
63
64        Array {
65            vec: Cow::from(&*self.vec),
66            shape,
67            strides,
68            idx_maps,
69        }
70    }
71
72    pub fn reshape<const S: usize>(&self, shape: [usize; S]) -> Array<'a, T, S> {
73        // TODO: Check wether cloning is necessary
74
75        let vec = self.flat().cloned().collect();
76
77        Array::init(vec, shape)
78    }
79
80    pub fn flatten(&self) -> Array<'a, T, 1> {
81        let vec = self.flat().cloned().collect();
82
83        Array::init(vec, [self.vec.len()])
84    }
85
86    pub fn ravel(&self) -> Array<'a, T, 1> {
87        self.reshape([self.vec.len()])
88    }
89}
90
91impl<'a, T: Clone + Default, const D: usize> Array<'a, T, D> {
92    pub fn resize<const S: usize>(mut self, shape: [usize; S]) -> Array<'a, T, S> {
93        let new_size = shape.iter().product();
94        let old_size = self.shape().iter().product();
95
96        if new_size > old_size {
97            self.vec.to_mut().resize_with(new_size, || T::default());
98        } else {
99            self.vec.to_mut().truncate(new_size);
100        }
101
102        self.reshape(shape)
103    }
104}
105
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn reshape_array() {
        // 2-D array:
        // 1 2 3
        // 4 5 6
        let array = Array::init(vec![1, 2, 3, 4, 5, 6], [2, 3]);

        // reshape it to the 3x2 2-D array:
        // 1 2
        // 3 4
        // 5 6
        let array = array.reshape([3, 2]);

        assert_eq!(array[[0, 0]], 1);
        assert_eq!(array[[0, 1]], 2);
        assert_eq!(array[[1, 0]], 3);
        assert_eq!(array[[1, 1]], 4);
        assert_eq!(array[[2, 0]], 5);
        assert_eq!(array[[2, 1]], 6);
    }

    #[test]
    fn transpose() {
        // 2-D array:
        // 1 2 3
        // 4 5 6
        let array = Array::init(vec![1, 2, 3, 4, 5, 6], [2, 3]);

        // transpose the array to:
        // 1 4
        // 2 5
        // 3 6
        let array = array.transpose();

        assert_eq!(array[[0, 0]], 1);
        assert_eq!(array[[0, 1]], 4);
        assert_eq!(array[[1, 0]], 2);
        assert_eq!(array[[1, 1]], 5);
        assert_eq!(array[[2, 0]], 3);
        assert_eq!(array[[2, 1]], 6);
    }

    #[test]
    fn t_returns_transposed_view() {
        // 2-D array:
        // 1 2 3
        // 4 5 6
        let array = Array::init(vec![1, 2, 3, 4, 5, 6], [2, 3]);

        // borrowed transpose:
        // 1 4
        // 2 5
        // 3 6
        let transposed = array.t();

        assert_eq!(transposed[[0, 1]], 4);
        assert_eq!(transposed[[1, 0]], 2);
        assert_eq!(transposed[[2, 1]], 6);

        // The original array is still usable and unchanged.
        assert_eq!(array[[0, 1]], 2);
    }

    #[test]
    fn transpose_the_reshape() {
        // 2-D array:
        // 1 2 3
        // 4 5 6
        let array = Array::init(vec![1, 2, 3, 4, 5, 6], [2, 3]);

        // transpose the array to:
        // 1 4
        // 2 5
        // 3 6
        let array = array.transpose();

        // reshape the array to a 2x3 2-D array:
        // 1 4 2
        // 5 3 6
        let array = array.reshape([2, 3]);

        assert_eq!(array[[0, 0]], 1);
        assert_eq!(array[[0, 1]], 4);
        assert_eq!(array[[0, 2]], 2);
        assert_eq!(array[[1, 0]], 5);
        assert_eq!(array[[1, 1]], 3);
        assert_eq!(array[[1, 2]], 6);
    }

    #[test]
    fn flip() {
        // 2-D array:
        // 1 2 3
        // 4 5 6
        let array = Array::init(vec![1, 2, 3, 4, 5, 6], [2, 3]);

        // flip axis = 0
        // 4 5 6
        // 1 2 3
        let array = array.flip(0);

        assert_eq!(
            array.flat().copied().collect::<Vec<usize>>(),
            vec![4, 5, 6, 1, 2, 3]
        );
    }

    #[test]
    fn flip_last_axis() {
        // 2-D array:
        // 1 2 3
        // 4 5 6
        let array = Array::init(vec![1, 2, 3, 4, 5, 6], [2, 3]);

        // flip axis = 1
        // 3 2 1
        // 6 5 4
        let array = array.flip(1);

        assert_eq!(
            array.flat().copied().collect::<Vec<usize>>(),
            vec![3, 2, 1, 6, 5, 4]
        );
    }

    #[test]
    #[should_panic(expected = "Axis out of bounds")]
    fn flip_rejects_out_of_bounds_axis() {
        let array = Array::init(vec![1, 2, 3, 4, 5, 6], [2, 3]);
        let _ = array.flip(2);
    }

    #[test]
    fn swap_axis() {
        // 2-D array:
        // 1 2 3
        let array = Array::init(vec![1, 2, 3], [1, 3]);

        let swapped_array = array.swap_axes(0, 1);

        assert_eq!(swapped_array[[0, 0]], 1);
        assert_eq!(swapped_array[[1, 0]], 2);
        assert_eq!(swapped_array[[2, 0]], 3);
    }

    #[test]
    fn swap_axes_of_2x3() {
        // 2-D array:
        // 1 2 3
        // 4 5 6
        let array = Array::init(vec![1, 2, 3, 4, 5, 6], [2, 3]);

        // swapping the only two axes of a 2-D array transposes it:
        // 1 4
        // 2 5
        // 3 6
        let swapped_array = array.swap_axes(0, 1);

        assert_eq!(swapped_array[[1, 0]], 2);
        assert_eq!(swapped_array[[0, 1]], 4);
        assert_eq!(swapped_array[[2, 1]], 6);
    }

    #[test]
    #[should_panic(expected = "Axis out of bounds")]
    fn swap_axes_rejects_out_of_bounds_axis() {
        let array = Array::init(vec![1, 2, 3, 4, 5, 6], [2, 3]);
        let _ = array.swap_axes(0, 2);
    }

    #[test]
    fn flatten() {
        // 2-D array:
        // 1 2 3
        // 4 5 6
        let array = Array::init(vec![1, 2, 3, 4, 5, 6], [2, 3]);

        let flatten_array = array.flatten();

        assert_eq!(
            flatten_array.flat().copied().collect::<Vec<usize>>(),
            vec![1, 2, 3, 4, 5, 6]
        );
    }

    #[test]
    fn ravel() {
        // 2-D array:
        // 1 2 3
        // 4 5 6
        let array = Array::init(vec![1, 2, 3, 4, 5, 6], [2, 3]);

        let raveled_array = array.ravel();

        assert_eq!(
            raveled_array.flat().copied().collect::<Vec<usize>>(),
            vec![1, 2, 3, 4, 5, 6]
        );
    }
}