1use rayon::prelude::*;
8use rayon::slice::ParallelSliceMut;
9
/// A dense n-dimensional array stored as one flat buffer plus per-axis extents.
///
/// [`Array::from_vec`] guarantees that `data.len()` equals the product of the
/// entries of `shape`.
#[derive(Debug, Clone)]
pub struct Array<T> {
    /// All elements in a single contiguous buffer.
    data: Vec<T>,
    /// Extent of each axis; its product is the element count.
    shape: Vec<usize>,
}

impl<T> Array<T> {
    /// Builds an array from a flat element buffer and its shape.
    ///
    /// An empty `shape` has an element count of 1, so it requires exactly one
    /// element. A shape containing a zero-length axis requires an empty buffer.
    ///
    /// # Panics
    ///
    /// Panics if `data.len()` differs from the product of `shape`, or if that
    /// product does not fit in a `usize`.
    pub fn from_vec(data: Vec<T>, shape: Vec<usize>) -> Self {
        // Checked multiplication: a wrapped product could otherwise "validate"
        // a buffer whose length has nothing to do with the requested shape.
        // A zero-length axis makes the count 0 regardless of the other axes,
        // so it is handled before any multiplication can overflow.
        let expected_len = if shape.contains(&0) {
            0
        } else {
            shape
                .iter()
                .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
                .expect("Shape element count overflows usize")
        };
        assert_eq!(data.len(), expected_len, "Data length must match shape");
        Array { data, shape }
    }

    /// Returns the extent of each axis.
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }
}
28
29impl Array<f64> {
30 pub fn add_simd(&self, other: &Array<f64>) -> Array<f64> {
33 assert_eq!(self.shape, other.shape, "Arrays must have the same shape");
34
35 let mut result_data = Vec::with_capacity(self.data.len());
36 let chunks = self.data.len() / 4;
37
38 for i in 0..chunks {
40 let start = i * 4;
41 result_data.push(self.data[start] + other.data[start]);
42 result_data.push(self.data[start + 1] + other.data[start + 1]);
43 result_data.push(self.data[start + 2] + other.data[start + 2]);
44 result_data.push(self.data[start + 3] + other.data[start + 3]);
45 }
46
47 for i in chunks * 4..self.data.len() {
49 result_data.push(self.data[i] + other.data[i]);
50 }
51
52 Array::from_vec(result_data, self.shape.clone())
53 }
54 pub fn dot_1d_simd(&self, other: &Array<f64>) -> f64 {
57 assert_eq!(self.shape.len(), 1, "Arrays must be 1D");
58 assert_eq!(
59 self.shape[0], other.shape[0],
60 "Arrays must have same length"
61 );
62
63 let mut sum = 0.0;
64 let chunks = self.data.len() / 4;
65
66 for i in 0..chunks {
68 let start = i * 4;
69 sum += self.data[start] * other.data[start];
70 sum += self.data[start + 1] * other.data[start + 1];
71 sum += self.data[start + 2] * other.data[start + 2];
72 sum += self.data[start + 3] * other.data[start + 3];
73 }
74
75 for i in chunks * 4..self.data.len() {
77 sum += self.data[i] * other.data[i];
78 }
79
80 sum
81 }
82
83 pub fn matmul_parallel(&self, other: &Array<f64>) -> Array<f64> {
85 assert_eq!(self.shape.len(), 2);
86 assert_eq!(other.shape.len(), 2);
87 assert_eq!(self.shape[1], other.shape[0], "Inner dimensions must match");
88
89 let (m, k) = (self.shape[0], self.shape[1]);
90 let n = other.shape[1];
91 let mut result = vec![0.0; m * n];
92
93 result.par_chunks_mut(n).enumerate().for_each(|(i, row)| {
94 for (j, item) in row.iter_mut().enumerate().take(n) {
95 let mut sum = 0.0;
96 for kk in 0..k {
97 sum += self.data[i * k + kk] * other.data[kk * n + j];
98 }
99 *item = sum;
100 }
101 });
102
103 Array::from_vec(result, vec![m, n])
104 }
105
106 pub fn matmul_blocked(&self, other: &Array<f64>, block_size: usize) -> Array<f64> {
108 assert_eq!(self.shape.len(), 2);
109 assert_eq!(other.shape.len(), 2);
110 assert_eq!(self.shape[1], other.shape[0], "Inner dimensions must match");
111
112 let (m, k) = (self.shape[0], self.shape[1]);
113 let n = other.shape[1];
114 let mut result = vec![0.0; m * n];
115
116 for i0 in (0..m).step_by(block_size) {
117 for j0 in (0..n).step_by(block_size) {
118 for k0 in (0..k).step_by(block_size) {
119 let i_max = (i0 + block_size).min(m);
120 let j_max = (j0 + block_size).min(n);
121 let k_max = (k0 + block_size).min(k);
122
123 for i in i0..i_max {
124 for j in j0..j_max {
125 let mut sum = 0.0;
126 for kk in k0..k_max {
127 sum += self.data[i * k + kk] * other.data[kk * n + j];
128 }
129 result[i * n + j] += sum;
130 }
131 }
132 }
133 }
134 }
135
136 Array::from_vec(result, vec![m, n])
137 }
138}
139
#[cfg(test)]
mod tests {
    //! Unit tests for the `Array` arithmetic routines.
    //!
    //! Inputs use small integers stored as `f64`, so every expected value is
    //! exactly representable and `assert_eq!` on floats is safe.

    use super::*;

    #[test]
    fn test_add_simd() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![4]);
        let b = Array::from_vec(vec![5.0, 6.0, 7.0, 8.0], vec![4]);
        let c = a.add_simd(&b);
        assert_eq!(c.data, vec![6.0, 8.0, 10.0, 12.0]);
    }

    // Lengths that are not a multiple of four, plus a 2-D shape.
    #[test]
    fn test_add_simd_odd_length_and_2d() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let b = Array::from_vec(vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0], vec![2, 3]);
        let c = a.add_simd(&b);
        assert_eq!(c.data, vec![11.0, 22.0, 33.0, 44.0, 55.0, 66.0]);
        assert_eq!(c.shape(), &[2usize, 3][..]);
    }

    #[test]
    #[should_panic(expected = "Arrays must have the same shape")]
    fn test_add_simd_shape_mismatch() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![4]);
        let b = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        let _ = a.add_simd(&b);
    }

    #[test]
    fn test_dot_1d_simd() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![4]);
        let b = Array::from_vec(vec![5.0, 6.0, 7.0, 8.0], vec![4]);
        let result = a.dot_1d_simd(&b);
        assert_eq!(result, 70.0);
    }

    // Five elements: one full group of four plus a one-element tail.
    #[test]
    fn test_dot_1d_simd_odd_length() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], vec![5]);
        let b = Array::from_vec(vec![5.0, 6.0, 7.0, 8.0, 9.0], vec![5]);
        assert_eq!(a.dot_1d_simd(&b), 115.0);
    }

    #[test]
    #[should_panic(expected = "Arrays must have same length")]
    fn test_dot_1d_simd_length_mismatch() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0], vec![3]);
        let b = Array::from_vec(vec![1.0, 2.0], vec![2]);
        let _ = a.dot_1d_simd(&b);
    }

    #[test]
    fn test_matmul_parallel() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        let b = Array::from_vec(vec![5.0, 6.0, 7.0, 8.0], vec![2, 2]);
        let c = a.matmul_parallel(&b);
        assert_eq!(c.data, vec![19.0, 22.0, 43.0, 50.0]);
    }

    // Non-square operands catch swapped row/column strides.
    #[test]
    fn test_matmul_parallel_non_square() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let b = Array::from_vec(vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0], vec![3, 2]);
        let c = a.matmul_parallel(&b);
        assert_eq!(c.data, vec![58.0, 64.0, 139.0, 154.0]);
        assert_eq!(c.shape(), &[2usize, 2][..]);
    }

    #[test]
    fn test_matmul_blocked() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        let b = Array::from_vec(vec![5.0, 6.0, 7.0, 8.0], vec![2, 2]);
        let c = a.matmul_blocked(&b, 2);
        assert_eq!(c.data, vec![19.0, 22.0, 43.0, 50.0]);
    }

    // Block sizes that do not divide the dimensions force partial tiles and
    // accumulation across several k-tiles.
    #[test]
    fn test_matmul_blocked_uneven_blocks() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let b = Array::from_vec(vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0], vec![3, 2]);
        for block_size in [1, 2, 5] {
            let c = a.matmul_blocked(&b, block_size);
            assert_eq!(c.data, vec![58.0, 64.0, 139.0, 154.0]);
            assert_eq!(c.shape(), &[2usize, 2][..]);
        }
    }

    #[test]
    #[should_panic(expected = "Inner dimensions must match")]
    fn test_matmul_inner_dimension_mismatch() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let b = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        let _ = a.matmul_blocked(&b, 2);
    }

    #[test]
    #[should_panic(expected = "Data length must match shape")]
    fn test_from_vec_length_mismatch() {
        let _ = Array::from_vec(vec![1.0, 2.0, 3.0], vec![2, 2]);
    }
}