rustframes/array/ops.rs

1use super::Array;
2use core::f64;
3use std::ops::{Add, Div, Mul, Sub};
4
5impl Array<f64> {
6    /// Check if two shapes are broadcastable
7    pub fn shapes_broadcastable(shape1: &[usize], shape2: &[usize]) -> bool {
8        let max_dims = shape1.len().max(shape2.len());
9        for i in 0..max_dims {
10            let dim1 = shape1.get(shape1.len().wrapping_sub(i + 1)).unwrap_or(&1);
11            let dim2 = shape2.get(shape2.len().wrapping_sub(i + 1)).unwrap_or(&1);
12            if *dim1 != *dim2 && *dim1 != 1 && *dim2 != 1 {
13                return false;
14            }
15        }
16        true
17    }
18
19    /// Compute the resulting shape after broadcasting
20    pub fn broadcast_shapes(shape1: &[usize], shape2: &[usize]) -> Option<Vec<usize>> {
21        if !Self::shapes_broadcastable(shape1, shape2) {
22            return None;
23        }
24
25        let max_dims = shape1.len().max(shape2.len());
26        let mut result_shape = vec![1; max_dims];
27
28        for i in 0..max_dims {
29            let dim1 = shape1.get(shape1.len().wrapping_sub(i + 1)).unwrap_or(&1);
30            let dim2 = shape2.get(shape2.len().wrapping_sub(i + 1)).unwrap_or(&1);
31            result_shape[max_dims - i - 1] = (*dim1).max(*dim2);
32        }
33
34        Some(result_shape)
35    }
36
37    /// Get element with broadcasting
38    pub fn get_broadcasted(&self, indices: &[usize], target_shape: &[usize]) -> &f64 {
39        let mut actual_indices = vec![0; self.shape.len()];
40        let shape_offset = target_shape.len() - self.shape.len();
41
42        for (i, &target_idx) in indices.iter().enumerate() {
43            if i >= shape_offset {
44                let self_dim_idx = i - shape_offset;
45                if self_dim_idx < self.shape.len() {
46                    if self.shape[self_dim_idx] == 1 {
47                        actual_indices[self_dim_idx] = 0; // Broadcast dimension
48                    } else {
49                        actual_indices[self_dim_idx] = target_idx;
50                    }
51                }
52            }
53        }
54
55        &self[actual_indices.as_slice()]
56    }
57
58    /// Element-wise addition with broadcasting
59    pub fn add_broadcast(&self, other: &Array<f64>) -> Option<Array<f64>> {
60        let result_shape = Self::broadcast_shapes(&self.shape, &other.shape)?;
61        let mut result = Array::zeros(result_shape.clone());
62
63        let total_elements: usize = result_shape.iter().product();
64        for flat_idx in 0..total_elements {
65            let indices = Self::unravel_index(flat_idx, &result_shape);
66            let val1 = self.get_broadcasted(&indices, &result_shape);
67            let val2 = other.get_broadcasted(&indices, &result_shape);
68            result[indices.as_slice()] = val1 + val2;
69        }
70
71        Some(result)
72    }
73
74    /// Element-wise subtraction with broadcasting
75    pub fn sub_broadcast(&self, other: &Array<f64>) -> Option<Array<f64>> {
76        let result_shape = Self::broadcast_shapes(&self.shape, &other.shape)?;
77        let mut result = Array::zeros(result_shape.clone());
78
79        let total_elements: usize = result_shape.iter().product();
80        for flat_idx in 0..total_elements {
81            let indices = Self::unravel_index(flat_idx, &result_shape);
82            let val1 = self.get_broadcasted(&indices, &result_shape);
83            let val2 = other.get_broadcasted(&indices, &result_shape);
84            result[indices.as_slice()] = val1 - val2;
85        }
86
87        Some(result)
88    }
89
90    /// Element-wise multiplication with broadcasting
91    pub fn mul_broadcast(&self, other: &Array<f64>) -> Option<Array<f64>> {
92        let result_shape = Self::broadcast_shapes(&self.shape, &other.shape)?;
93        let mut result = Array::zeros(result_shape.clone());
94
95        let total_elements: usize = result_shape.iter().product();
96        for flat_idx in 0..total_elements {
97            let indices = Self::unravel_index(flat_idx, &result_shape);
98            let val1 = self.get_broadcasted(&indices, &result_shape);
99            let val2 = other.get_broadcasted(&indices, &result_shape);
100            result[indices.as_slice()] = val1 * val2;
101        }
102
103        Some(result)
104    }
105
106    /// Element-wise division with broadcasting
107    pub fn div_broadcast(&self, other: &Array<f64>) -> Option<Array<f64>> {
108        let result_shape = Self::broadcast_shapes(&self.shape, &other.shape)?;
109        let mut result = Array::zeros(result_shape.clone());
110
111        let total_elements: usize = result_shape.iter().product();
112        for flat_idx in 0..total_elements {
113            let indices = Self::unravel_index(flat_idx, &result_shape);
114            let val1 = self.get_broadcasted(&indices, &result_shape);
115            let val2 = other.get_broadcasted(&indices, &result_shape);
116            result[indices.as_slice()] = val1 / val2;
117        }
118
119        Some(result)
120    }
121
122    /// Convert flat index to multi-dimensional index
123    pub fn unravel_index(flat_index: usize, shape: &[usize]) -> Vec<usize> {
124        let mut indices = vec![0; shape.len()];
125        let mut remaining = flat_index;
126
127        for (i, &_dim_size) in shape.iter().enumerate() {
128            let stride: usize = shape[i + 1..].iter().product();
129            indices[i] = remaining / stride;
130            remaining %= stride;
131        }
132
133        indices
134    }
135
136    /// Scalar operations
137    pub fn add_scalar(&self, scalar: f64) -> Array<f64> {
138        let data: Vec<f64> = self.data.iter().map(|&x| x + scalar).collect();
139        Array::from_vec(data, self.shape.clone())
140    }
141
142    pub fn sub_scalar(&self, scalar: f64) -> Array<f64> {
143        let data: Vec<f64> = self.data.iter().map(|&x| x - scalar).collect();
144        Array::from_vec(data, self.shape.clone())
145    }
146
147    pub fn mul_scalar(&self, scalar: f64) -> Array<f64> {
148        let data: Vec<f64> = self.data.iter().map(|&x| x * scalar).collect();
149        Array::from_vec(data, self.shape.clone())
150    }
151
152    pub fn div_scalar(&self, scalar: f64) -> Array<f64> {
153        let data: Vec<f64> = self.data.iter().map(|&x| x / scalar).collect();
154        Array::from_vec(data, self.shape.clone())
155    }
156
157    /// Reduction operations
158    pub fn sum(&self) -> f64 {
159        self.data.iter().sum()
160    }
161
162    pub fn sum_axis(&self, axis: usize) -> Array<f64> {
163        assert!(axis < self.ndim(), "Axis out of bounds");
164
165        let mut result_shape = self.shape.clone();
166        result_shape.remove(axis);
167        if result_shape.is_empty() {
168            result_shape.push(1);
169        }
170
171        let mut result = Array::zeros(result_shape.clone());
172        let result_size: usize = result_shape.iter().product();
173
174        for result_idx in 0..result_size {
175            let mut sum = 0.0;
176            let result_indices = Self::unravel_index(result_idx, &result_shape);
177
178            for i in 0..self.shape[axis] {
179                let mut full_indices = Vec::new();
180                let mut result_iter = result_indices.iter();
181
182                for (dim_idx, _) in self.shape.iter().enumerate() {
183                    if dim_idx == axis {
184                        full_indices.push(i);
185                    } else {
186                        full_indices.push(*result_iter.next().unwrap());
187                    }
188                }
189
190                sum += self[full_indices.as_slice()];
191            }
192
193            result[result_indices.as_slice()] = sum;
194        }
195        result
196    }
197
198    pub fn mean(&self) -> f64 {
199        self.sum() / self.len() as f64
200    }
201
202    pub fn mean_axis(&self, axis: usize) -> Array<f64> {
203        let sum_result = self.sum_axis(axis);
204        let divisor = self.shape[axis] as f64;
205        sum_result.div_scalar(divisor)
206    }
207
208    pub fn max(&self) -> f64 {
209        self.data
210            .iter()
211            .fold(f64::NEG_INFINITY, |acc, &x| acc.max(x))
212    }
213
214    pub fn min(&self) -> f64 {
215        self.data.iter().fold(f64::INFINITY, |acc, &x| acc.min(x))
216    }
217
218    /// Element-wise mathematical functions
219    pub fn exp(&self) -> Array<f64> {
220        let data: Vec<f64> = self.data.iter().map(|&x| x.exp()).collect();
221        Array::from_vec(data, self.shape.clone())
222    }
223
224    pub fn ln(&self) -> Array<f64> {
225        let data: Vec<f64> = self.data.iter().map(|&x| x.ln()).collect();
226        Array::from_vec(data, self.shape.clone())
227    }
228
229    pub fn sin(&self) -> Array<f64> {
230        let data: Vec<f64> = self.data.iter().map(|&x| x.sin()).collect();
231        Array::from_vec(data, self.shape.clone())
232    }
233
234    pub fn cos(&self) -> Array<f64> {
235        let data: Vec<f64> = self.data.iter().map(|&x| x.cos()).collect();
236        Array::from_vec(data, self.shape.clone())
237    }
238
239    pub fn sqrt(&self) -> Array<f64> {
240        let data: Vec<f64> = self.data.iter().map(|&x| x.sqrt()).collect();
241        Array::from_vec(data, self.shape.clone())
242    }
243
244    pub fn pow(&self, exponent: f64) -> Array<f64> {
245        let data: Vec<f64> = self.data.iter().map(|&x| x.powf(exponent)).collect();
246        Array::from_vec(data, self.shape.clone())
247    }
248}
249
250// Operator implementations using broadcasting
251impl Add for &Array<f64> {
252    type Output = Array<f64>;
253    fn add(self, rhs: Self) -> Self::Output {
254        self.add_broadcast(rhs).expect("Shapes not broadcastable")
255    }
256}
257
258impl Sub for &Array<f64> {
259    type Output = Array<f64>;
260    fn sub(self, rhs: Self) -> Self::Output {
261        self.sub_broadcast(rhs).expect("Shapes not broadcastable")
262    }
263}
264
265impl Mul for &Array<f64> {
266    type Output = Array<f64>;
267    fn mul(self, rhs: Self) -> Self::Output {
268        self.mul_broadcast(rhs).expect("Shapes not broadcastable")
269    }
270}
271
272impl Div for &Array<f64> {
273    type Output = Array<f64>;
274    fn div(self, rhs: Self) -> Self::Output {
275        self.div_broadcast(rhs).expect("Shapes not broadcastable")
276    }
277}
278
279// Scalar operations using trait implementations
280impl Add<f64> for &Array<f64> {
281    type Output = Array<f64>;
282    fn add(self, scalar: f64) -> Self::Output {
283        self.add_scalar(scalar)
284    }
285}
286
287impl Sub<f64> for &Array<f64> {
288    type Output = Array<f64>;
289    fn sub(self, scalar: f64) -> Self::Output {
290        self.sub_scalar(scalar)
291    }
292}
293
294impl Mul<f64> for &Array<f64> {
295    type Output = Array<f64>;
296    fn mul(self, scalar: f64) -> Self::Output {
297        self.mul_scalar(scalar)
298    }
299}
300
301impl Div<f64> for &Array<f64> {
302    type Output = Array<f64>;
303    fn div(self, scalar: f64) -> Self::Output {
304        self.div_scalar(scalar)
305    }
306}