1use super::Array;
2use core::f64;
3use std::ops::{Add, Div, Mul, Sub};
4
5impl Array<f64> {
6 pub fn shapes_broadcastable(shape1: &[usize], shape2: &[usize]) -> bool {
8 let max_dims = shape1.len().max(shape2.len());
9 for i in 0..max_dims {
10 let dim1 = shape1.get(shape1.len().wrapping_sub(i + 1)).unwrap_or(&1);
11 let dim2 = shape2.get(shape2.len().wrapping_sub(i + 1)).unwrap_or(&1);
12 if *dim1 != *dim2 && *dim1 != 1 && *dim2 != 1 {
13 return false;
14 }
15 }
16 true
17 }
18
19 pub fn broadcast_shapes(shape1: &[usize], shape2: &[usize]) -> Option<Vec<usize>> {
21 if !Self::shapes_broadcastable(shape1, shape2) {
22 return None;
23 }
24
25 let max_dims = shape1.len().max(shape2.len());
26 let mut result_shape = vec![1; max_dims];
27
28 for i in 0..max_dims {
29 let dim1 = shape1.get(shape1.len().wrapping_sub(i + 1)).unwrap_or(&1);
30 let dim2 = shape2.get(shape2.len().wrapping_sub(i + 1)).unwrap_or(&1);
31 result_shape[max_dims - i - 1] = (*dim1).max(*dim2);
32 }
33
34 Some(result_shape)
35 }
36
37 pub fn get_broadcasted(&self, indices: &[usize], target_shape: &[usize]) -> &f64 {
39 let mut actual_indices = vec![0; self.shape.len()];
40 let shape_offset = target_shape.len() - self.shape.len();
41
42 for (i, &target_idx) in indices.iter().enumerate() {
43 if i >= shape_offset {
44 let self_dim_idx = i - shape_offset;
45 if self_dim_idx < self.shape.len() {
46 if self.shape[self_dim_idx] == 1 {
47 actual_indices[self_dim_idx] = 0; } else {
49 actual_indices[self_dim_idx] = target_idx;
50 }
51 }
52 }
53 }
54
55 &self[actual_indices.as_slice()]
56 }
57
58 pub fn add_broadcast(&self, other: &Array<f64>) -> Option<Array<f64>> {
60 let result_shape = Self::broadcast_shapes(&self.shape, &other.shape)?;
61 let mut result = Array::zeros(result_shape.clone());
62
63 let total_elements: usize = result_shape.iter().product();
64 for flat_idx in 0..total_elements {
65 let indices = Self::unravel_index(flat_idx, &result_shape);
66 let val1 = self.get_broadcasted(&indices, &result_shape);
67 let val2 = other.get_broadcasted(&indices, &result_shape);
68 result[indices.as_slice()] = val1 + val2;
69 }
70
71 Some(result)
72 }
73
74 pub fn sub_broadcast(&self, other: &Array<f64>) -> Option<Array<f64>> {
76 let result_shape = Self::broadcast_shapes(&self.shape, &other.shape)?;
77 let mut result = Array::zeros(result_shape.clone());
78
79 let total_elements: usize = result_shape.iter().product();
80 for flat_idx in 0..total_elements {
81 let indices = Self::unravel_index(flat_idx, &result_shape);
82 let val1 = self.get_broadcasted(&indices, &result_shape);
83 let val2 = other.get_broadcasted(&indices, &result_shape);
84 result[indices.as_slice()] = val1 - val2;
85 }
86
87 Some(result)
88 }
89
90 pub fn mul_broadcast(&self, other: &Array<f64>) -> Option<Array<f64>> {
92 let result_shape = Self::broadcast_shapes(&self.shape, &other.shape)?;
93 let mut result = Array::zeros(result_shape.clone());
94
95 let total_elements: usize = result_shape.iter().product();
96 for flat_idx in 0..total_elements {
97 let indices = Self::unravel_index(flat_idx, &result_shape);
98 let val1 = self.get_broadcasted(&indices, &result_shape);
99 let val2 = other.get_broadcasted(&indices, &result_shape);
100 result[indices.as_slice()] = val1 * val2;
101 }
102
103 Some(result)
104 }
105
106 pub fn div_broadcast(&self, other: &Array<f64>) -> Option<Array<f64>> {
108 let result_shape = Self::broadcast_shapes(&self.shape, &other.shape)?;
109 let mut result = Array::zeros(result_shape.clone());
110
111 let total_elements: usize = result_shape.iter().product();
112 for flat_idx in 0..total_elements {
113 let indices = Self::unravel_index(flat_idx, &result_shape);
114 let val1 = self.get_broadcasted(&indices, &result_shape);
115 let val2 = other.get_broadcasted(&indices, &result_shape);
116 result[indices.as_slice()] = val1 / val2;
117 }
118
119 Some(result)
120 }
121
122 pub fn unravel_index(flat_index: usize, shape: &[usize]) -> Vec<usize> {
124 let mut indices = vec![0; shape.len()];
125 let mut remaining = flat_index;
126
127 for (i, &_dim_size) in shape.iter().enumerate() {
128 let stride: usize = shape[i + 1..].iter().product();
129 indices[i] = remaining / stride;
130 remaining %= stride;
131 }
132
133 indices
134 }
135
136 pub fn add_scalar(&self, scalar: f64) -> Array<f64> {
138 let data: Vec<f64> = self.data.iter().map(|&x| x + scalar).collect();
139 Array::from_vec(data, self.shape.clone())
140 }
141
142 pub fn sub_scalar(&self, scalar: f64) -> Array<f64> {
143 let data: Vec<f64> = self.data.iter().map(|&x| x - scalar).collect();
144 Array::from_vec(data, self.shape.clone())
145 }
146
147 pub fn mul_scalar(&self, scalar: f64) -> Array<f64> {
148 let data: Vec<f64> = self.data.iter().map(|&x| x * scalar).collect();
149 Array::from_vec(data, self.shape.clone())
150 }
151
152 pub fn div_scalar(&self, scalar: f64) -> Array<f64> {
153 let data: Vec<f64> = self.data.iter().map(|&x| x / scalar).collect();
154 Array::from_vec(data, self.shape.clone())
155 }
156
157 pub fn sum(&self) -> f64 {
159 self.data.iter().sum()
160 }
161
162 pub fn sum_axis(&self, axis: usize) -> Array<f64> {
163 assert!(axis < self.ndim(), "Axis out of bounds");
164
165 let mut result_shape = self.shape.clone();
166 result_shape.remove(axis);
167 if result_shape.is_empty() {
168 result_shape.push(1);
169 }
170
171 let mut result = Array::zeros(result_shape.clone());
172 let result_size: usize = result_shape.iter().product();
173
174 for result_idx in 0..result_size {
175 let mut sum = 0.0;
176 let result_indices = Self::unravel_index(result_idx, &result_shape);
177
178 for i in 0..self.shape[axis] {
179 let mut full_indices = Vec::new();
180 let mut result_iter = result_indices.iter();
181
182 for (dim_idx, _) in self.shape.iter().enumerate() {
183 if dim_idx == axis {
184 full_indices.push(i);
185 } else {
186 full_indices.push(*result_iter.next().unwrap());
187 }
188 }
189
190 sum += self[full_indices.as_slice()];
191 }
192
193 result[result_indices.as_slice()] = sum;
194 }
195 result
196 }
197
198 pub fn mean(&self) -> f64 {
199 self.sum() / self.len() as f64
200 }
201
202 pub fn mean_axis(&self, axis: usize) -> Array<f64> {
203 let sum_result = self.sum_axis(axis);
204 let divisor = self.shape[axis] as f64;
205 sum_result.div_scalar(divisor)
206 }
207
208 pub fn max(&self) -> f64 {
209 self.data
210 .iter()
211 .fold(f64::NEG_INFINITY, |acc, &x| acc.max(x))
212 }
213
214 pub fn min(&self) -> f64 {
215 self.data.iter().fold(f64::INFINITY, |acc, &x| acc.min(x))
216 }
217
218 pub fn exp(&self) -> Array<f64> {
220 let data: Vec<f64> = self.data.iter().map(|&x| x.exp()).collect();
221 Array::from_vec(data, self.shape.clone())
222 }
223
224 pub fn ln(&self) -> Array<f64> {
225 let data: Vec<f64> = self.data.iter().map(|&x| x.ln()).collect();
226 Array::from_vec(data, self.shape.clone())
227 }
228
229 pub fn sin(&self) -> Array<f64> {
230 let data: Vec<f64> = self.data.iter().map(|&x| x.sin()).collect();
231 Array::from_vec(data, self.shape.clone())
232 }
233
234 pub fn cos(&self) -> Array<f64> {
235 let data: Vec<f64> = self.data.iter().map(|&x| x.cos()).collect();
236 Array::from_vec(data, self.shape.clone())
237 }
238
239 pub fn sqrt(&self) -> Array<f64> {
240 let data: Vec<f64> = self.data.iter().map(|&x| x.sqrt()).collect();
241 Array::from_vec(data, self.shape.clone())
242 }
243
244 pub fn pow(&self, exponent: f64) -> Array<f64> {
245 let data: Vec<f64> = self.data.iter().map(|&x| x.powf(exponent)).collect();
246 Array::from_vec(data, self.shape.clone())
247 }
248}
249
250impl Add for &Array<f64> {
252 type Output = Array<f64>;
253 fn add(self, rhs: Self) -> Self::Output {
254 self.add_broadcast(rhs).expect("Shapes not broadcastable")
255 }
256}
257
258impl Sub for &Array<f64> {
259 type Output = Array<f64>;
260 fn sub(self, rhs: Self) -> Self::Output {
261 self.sub_broadcast(rhs).expect("Shapes not broadcastable")
262 }
263}
264
265impl Mul for &Array<f64> {
266 type Output = Array<f64>;
267 fn mul(self, rhs: Self) -> Self::Output {
268 self.mul_broadcast(rhs).expect("Shapes not broadcastable")
269 }
270}
271
272impl Div for &Array<f64> {
273 type Output = Array<f64>;
274 fn div(self, rhs: Self) -> Self::Output {
275 self.div_broadcast(rhs).expect("Shapes not broadcastable")
276 }
277}
278
279impl Add<f64> for &Array<f64> {
281 type Output = Array<f64>;
282 fn add(self, scalar: f64) -> Self::Output {
283 self.add_scalar(scalar)
284 }
285}
286
287impl Sub<f64> for &Array<f64> {
288 type Output = Array<f64>;
289 fn sub(self, scalar: f64) -> Self::Output {
290 self.sub_scalar(scalar)
291 }
292}
293
294impl Mul<f64> for &Array<f64> {
295 type Output = Array<f64>;
296 fn mul(self, scalar: f64) -> Self::Output {
297 self.mul_scalar(scalar)
298 }
299}
300
301impl Div<f64> for &Array<f64> {
302 type Output = Array<f64>;
303 fn div(self, scalar: f64) -> Self::Output {
304 self.div_scalar(scalar)
305 }
306}