1mod access;
2mod calc;
3mod iter;
4mod ops;
5mod transformation;
6
7use std::borrow::Cow;
8
9use num_traits::{One, Zero};
10
/// Per-axis affine index transform: `idx ↦ m * idx + b`.
///
/// `Array` keeps one of these per axis (see its `idx_maps` field).
#[derive(Clone, Copy, Debug)]
struct IdxMap {
    /// Multiplier applied to the incoming index.
    m: isize,
    /// Offset added after scaling.
    b: isize,
}

impl IdxMap {
    /// The identity map (`m = 1`, `b = 0`), so `map(idx) == idx`.
    fn init() -> Self {
        Self { m: 1, b: 0 }
    }

    /// Applies the transform to `idx`, returning `m * idx + b`.
    ///
    /// NOTE(review): the signed result is cast to `usize` unchecked, so a negative value
    /// wraps around; presumably callers only build maps that stay non-negative — confirm.
    fn map(&self, idx: usize) -> usize {
        let scaled = self.m * idx as isize;
        (scaled + self.b) as usize
    }

    /// Shifts the origin by `b` steps in the current direction.
    ///
    /// Afterwards `map(i)` returns what `map(i + b)` returned before.
    fn append_b(&mut self, b: isize) {
        self.b += b * self.m;
    }
}
30
31pub struct Array<'a, T: Clone, const D: usize> {
32 vec: Cow<'a, [T]>,
33 shape: [usize; D],
34 strides: [usize; D],
35 idx_maps: [IdxMap; D],
36}
37
38impl<'a, T: Clone, const D: usize> Array<'a, T, D> {
39 pub fn init(vec: Vec<T>, shape: [usize; D]) -> Self {
40 let elem_count: usize = shape.iter().product();
41
42 if elem_count != vec.len() {
43 panic!(
44 "Number of elements in vec is not equal to dimension specification: {} != {}",
45 vec.len(),
46 elem_count
47 );
48 }
49
50 let mut strides = [0; D];
51 for axis in 0..D {
52 strides[axis] = shape[axis + 1..].iter().fold(1, |acc, v| acc * v);
53 }
54
55 Array {
56 vec: Cow::from(vec),
57 shape,
58 strides,
59 idx_maps: [IdxMap::init(); D],
60 }
61 }
62
63 pub fn shape(&self) -> &[usize; D] {
64 &self.shape
65 }
66
67 pub fn strides(&self) -> &[usize; D] {
68 &self.strides
69 }
70
71 pub fn full(val: T, shape: [usize; D]) -> Array<'a, T, D> {
72 Array::init(vec![val; shape.iter().product()], shape)
73 }
74
75 pub fn full_like<'b, U: Clone>(val: T, array: &Array<'b, U, D>) -> Array<'a, T, D> {
76 Array::full(val, array.shape().clone())
77 }
78}
79
80impl<'a, T: Clone> Array<'a, T, 1> {
81 pub fn arange<I: Iterator<Item = T>>(range: I) -> Array<'a, T, 1> {
82 let vec: Vec<T> = range.collect();
83 let len = vec.len();
84
85 Array::init(vec, [len])
86 }
87}
88
89impl<'a, T: Clone + Zero, const D: usize> Array<'a, T, D> {
90 pub fn zeros(shape: [usize; D]) -> Self {
91 Array::init(vec![T::zero(); shape.iter().product()], shape)
92 }
93
94 pub fn zeros_like<'b, U: Clone>(array: &Array<'b, U, D>) -> Array<'a, T, D> {
95 Array::zeros(array.shape().clone())
96 }
97}
98
99impl<'a, T: Clone + One, const D: usize> Array<'a, T, D> {
100 pub fn ones(shape: [usize; D]) -> Self {
101 Array::init(vec![T::one(); shape.iter().product()], shape)
102 }
103
104 pub fn ones_like<'b, U: Clone>(array: &Array<'b, U, D>) -> Array<'a, T, D> {
105 Array::ones(array.shape().clone())
106 }
107}
108
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn arange() {
        let array = Array::arange(0..10);

        assert_eq!(
            array.flat().copied().collect::<Vec<usize>>(),
            vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
        )
    }

    #[test]
    fn zeros() {
        let array = Array::zeros([2, 4]);

        assert_eq!(
            array.flat().copied().collect::<Vec<usize>>(),
            vec![0, 0, 0, 0, 0, 0, 0, 0]
        )
    }

    #[test]
    fn zeros_like() {
        let array = Array::arange(0..8).reshape([2, 4]);

        let zeros_like = Array::zeros_like(&array);

        assert_eq!(
            zeros_like.flat().copied().collect::<Vec<usize>>(),
            vec![0, 0, 0, 0, 0, 0, 0, 0]
        )
    }

    #[test]
    fn ones() {
        let array = Array::ones([2, 4]);

        assert_eq!(
            array.flat().copied().collect::<Vec<usize>>(),
            vec![1, 1, 1, 1, 1, 1, 1, 1]
        )
    }

    #[test]
    fn ones_like() {
        let array = Array::arange(0..8).reshape([2, 4]);

        let ones_like = Array::ones_like(&array);

        assert_eq!(
            ones_like.flat().copied().collect::<Vec<usize>>(),
            vec![1, 1, 1, 1, 1, 1, 1, 1]
        )
    }

    #[test]
    fn full() {
        let array = Array::full(10, [2, 4]);

        assert_eq!(
            array.flat().copied().collect::<Vec<usize>>(),
            vec![10, 10, 10, 10, 10, 10, 10, 10]
        )
    }

    #[test]
    fn full_like() {
        let array = Array::arange(0..8).reshape([2, 4]);

        let full_like = Array::full_like(10, &array);

        assert_eq!(
            full_like.flat().copied().collect::<Vec<usize>>(),
            vec![10, 10, 10, 10, 10, 10, 10, 10]
        )
    }

    // `init` must produce contiguous row-major strides (last axis has stride 1).
    #[test]
    fn init_computes_row_major_strides() {
        let array = Array::init(vec![0usize; 24], [2, 3, 4]);

        assert_eq!(array.shape(), &[2, 3, 4]);
        assert_eq!(array.strides(), &[12, 4, 1]);
    }

    // A buffer whose length does not match the shape's element count is rejected.
    #[test]
    #[should_panic(expected = "Number of elements in vec is not equal to dimension specification")]
    fn init_rejects_mismatched_len() {
        let _ = Array::init(vec![0usize; 5], [2, 3]);
    }
}