nd_array/ndarray/array/
mod.rs

1mod access;
2mod calc;
3mod iter;
4mod ops;
5mod transformation;
6
7use std::borrow::Cow;
8
9use num_traits::{One, Zero};
10
/// Affine transformation applied to a single-axis index: `idx -> m * idx + b`.
#[derive(Debug, Clone, Copy)]
struct IdxMap {
    /// Multiplier applied to the incoming index.
    m: isize,
    /// Offset added after the multiplication.
    b: isize,
}

impl IdxMap {
    /// Builds the identity map (`m = 1`, `b = 0`), which returns every index unchanged.
    fn init() -> Self {
        Self { m: 1, b: 0 }
    }

    /// Applies the map to `idx` and returns `m * idx + b`.
    ///
    /// The conversions to and from `isize` are plain `as` casts, so they are
    /// not range-checked.
    fn map(&self, idx: usize) -> usize {
        let scaled = self.m * (idx as isize);
        (scaled + self.b) as usize
    }

    /// Adds `m * b` to the stored offset, i.e. shifts the offset by `b`
    /// steps expressed in units of the current multiplier.
    fn append_b(&mut self, b: isize) {
        let shift = self.m * b;
        self.b += shift;
    }
}
30
31pub struct Array<'a, T: Clone, const D: usize> {
32    vec: Cow<'a, [T]>,
33    shape: [usize; D],
34    strides: [usize; D],
35    idx_maps: [IdxMap; D],
36}
37
38impl<'a, T: Clone, const D: usize> Array<'a, T, D> {
39    pub fn init(vec: Vec<T>, shape: [usize; D]) -> Self {
40        let elem_count: usize = shape.iter().product();
41
42        if elem_count != vec.len() {
43            panic!(
44                "Number of elements in vec is not equal to dimension specification: {} != {}",
45                vec.len(),
46                elem_count
47            );
48        }
49
50        let mut strides = [0; D];
51        for axis in 0..D {
52            strides[axis] = shape[axis + 1..].iter().fold(1, |acc, v| acc * v);
53        }
54
55        Array {
56            vec: Cow::from(vec),
57            shape,
58            strides,
59            idx_maps: [IdxMap::init(); D],
60        }
61    }
62
63    pub fn shape(&self) -> &[usize; D] {
64        &self.shape
65    }
66
67    pub fn strides(&self) -> &[usize; D] {
68        &self.strides
69    }
70
71    pub fn full(val: T, shape: [usize; D]) -> Array<'a, T, D> {
72        Array::init(vec![val; shape.iter().product()], shape)
73    }
74
75    pub fn full_like<'b, U: Clone>(val: T, array: &Array<'b, U, D>) -> Array<'a, T, D> {
76        Array::full(val, array.shape().clone())
77    }
78}
79
80impl<'a, T: Clone> Array<'a, T, 1> {
81    pub fn arange<I: Iterator<Item = T>>(range: I) -> Array<'a, T, 1> {
82        let vec: Vec<T> = range.collect();
83        let len = vec.len();
84
85        Array::init(vec, [len])
86    }
87}
88
89impl<'a, T: Clone + Zero, const D: usize> Array<'a, T, D> {
90    pub fn zeros(shape: [usize; D]) -> Self {
91        Array::init(vec![T::zero(); shape.iter().product()], shape)
92    }
93
94    pub fn zeros_like<'b, U: Clone>(array: &Array<'b, U, D>) -> Array<'a, T, D> {
95        Array::zeros(array.shape().clone())
96    }
97}
98
99impl<'a, T: Clone + One, const D: usize> Array<'a, T, D> {
100    pub fn ones(shape: [usize; D]) -> Self {
101        Array::init(vec![T::one(); shape.iter().product()], shape)
102    }
103
104    pub fn ones_like<'b, U: Clone>(array: &Array<'b, U, D>) -> Array<'a, T, D> {
105        Array::ones(array.shape().clone())
106    }
107}
108
#[cfg(test)]
mod tests {
    use super::*;

    // Each test builds an array with one constructor and compares its
    // flattened contents against the expected sequence of values.

    #[test]
    fn arange() {
        let array = Array::arange(0..10);
        let values: Vec<usize> = array.flat().copied().collect();

        assert_eq!(values, (0..10).collect::<Vec<usize>>());
    }

    #[test]
    fn zeros() {
        let array = Array::zeros([2, 4]);
        let values: Vec<usize> = array.flat().copied().collect();

        assert_eq!(values, vec![0; 8]);
    }

    #[test]
    fn zeros_like() {
        let template = Array::arange(0..8).reshape([2, 4]);
        let zeros_like = Array::zeros_like(&template);
        let values: Vec<usize> = zeros_like.flat().copied().collect();

        assert_eq!(values, vec![0; 8]);
    }

    #[test]
    fn ones() {
        let array = Array::ones([2, 4]);
        let values: Vec<usize> = array.flat().copied().collect();

        assert_eq!(values, vec![1; 8]);
    }

    #[test]
    fn ones_like() {
        let template = Array::arange(0..8).reshape([2, 4]);
        let ones_like = Array::ones_like(&template);
        let values: Vec<usize> = ones_like.flat().copied().collect();

        assert_eq!(values, vec![1; 8]);
    }

    #[test]
    fn full() {
        let array = Array::full(10, [2, 4]);
        let values: Vec<usize> = array.flat().copied().collect();

        assert_eq!(values, vec![10; 8]);
    }

    #[test]
    fn full_like() {
        let template = Array::arange(0..8).reshape([2, 4]);
        let full_like = Array::full_like(10, &template);
        let values: Vec<usize> = full_like.flat().copied().collect();

        assert_eq!(values, vec![10; 8]);
    }
}