ghostflow_core/ops/reduction.rs

//! Reduction operations (sum, mean, max, min, etc.)

use crate::tensor::Tensor;
use crate::error::{GhostError, Result};
use rayon::prelude::*;
impl Tensor {
    /// Sum all elements
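    ///
    /// A usage sketch (mirrors the constructors exercised in the tests below):
    ///
    /// ```ignore
    /// let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0], &[3]).unwrap();
    /// assert_eq!(t.sum().data_f32()[0], 6.0);
    /// ```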
    pub fn sum(&self) -> Tensor {
        let data = self.data_f32();
        let sum: f32 = data.par_iter().sum();
        Tensor::from_slice(&[sum], &[]).unwrap()
    }

    /// Sum along a dimension; `keepdim` keeps the reduced axis with size 1.
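    ///
    /// Shape sketch for a 2x3 input, assuming the row-major layout that
    /// `reduce_dim` below relies on: reducing dim 0 yields shape [3], or
    /// [1, 3] with `keepdim`.
    ///
    /// ```ignore
    /// let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
    /// let s = t.sum_dim(0, false).unwrap();
    /// assert_eq!(s.data_f32()[0], 5.0); // column sums: [5, 7, 9]
    /// ```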
    pub fn sum_dim(&self, dim: usize, keepdim: bool) -> Result<Tensor> {
        self.reduce_dim(dim, keepdim, |slice| slice.iter().sum())
    }

    /// Mean of all elements
    pub fn mean(&self) -> Tensor {
        let data = self.data_f32();
        let sum: f32 = data.par_iter().sum();
        let mean = sum / data.len() as f32;
        Tensor::from_slice(&[mean], &[]).unwrap()
    }

    /// Mean along dimension
    pub fn mean_dim(&self, dim: usize, keepdim: bool) -> Result<Tensor> {
        // Validate before indexing into dims(), so an out-of-range dim
        // returns an error instead of panicking here.
        if dim >= self.ndim() {
            return Err(GhostError::DimOutOfBounds {
                dim,
                ndim: self.ndim(),
            });
        }
        let dim_size = self.dims()[dim] as f32;
        self.reduce_dim(dim, keepdim, |slice| {
            slice.iter().sum::<f32>() / dim_size
        })
    }

    /// Maximum element
    pub fn max(&self) -> Tensor {
        let data = self.data_f32();
        let max = data.par_iter().cloned().reduce(|| f32::NEG_INFINITY, f32::max);
        Tensor::from_slice(&[max], &[]).unwrap()
    }

    /// Maximum along dimension
    pub fn max_dim(&self, dim: usize, keepdim: bool) -> Result<Tensor> {
        self.reduce_dim(dim, keepdim, |slice| {
            slice.iter().cloned().fold(f32::NEG_INFINITY, f32::max)
        })
    }

    /// Minimum element
    pub fn min(&self) -> Tensor {
        let data = self.data_f32();
        let min = data.par_iter().cloned().reduce(|| f32::INFINITY, f32::min);
        Tensor::from_slice(&[min], &[]).unwrap()
    }

    /// Minimum along dimension
    pub fn min_dim(&self, dim: usize, keepdim: bool) -> Result<Tensor> {
        self.reduce_dim(dim, keepdim, |slice| {
            slice.iter().cloned().fold(f32::INFINITY, f32::min)
        })
    }

    /// Product of all elements
    pub fn prod(&self) -> Tensor {
        let data = self.data_f32();
        let prod: f32 = data.par_iter().product();
        Tensor::from_slice(&[prod], &[]).unwrap()
    }

    /// Variance of all elements
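    ///
    /// `unbiased = true` divides by `n - 1` (Bessel's correction),
    /// `unbiased = false` divides by `n`:
    ///
    /// ```ignore
    /// let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0], &[5]).unwrap();
    /// assert_eq!(t.var(false).data_f32()[0], 2.0); // population variance
    /// assert_eq!(t.var(true).data_f32()[0], 2.5);  // sample variance
    /// ```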
    pub fn var(&self, unbiased: bool) -> Tensor {
        let data = self.data_f32();
        let n = data.len() as f32;
        let mean: f32 = data.par_iter().sum::<f32>() / n;
        let var: f32 = data.par_iter().map(|&x| (x - mean).powi(2)).sum::<f32>();
        let divisor = if unbiased { n - 1.0 } else { n };
        Tensor::from_slice(&[var / divisor], &[]).unwrap()
    }

    /// Standard deviation
    pub fn std(&self, unbiased: bool) -> Tensor {
        let var = self.var(unbiased);
        var.sqrt()
    }

    /// Argmax - flat index of the maximum element, returned as an f32 scalar.
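    ///
    /// A sketch; the index comes back as f32 (that is how this
    /// implementation stores it), so exact small-integer comparison is safe:
    ///
    /// ```ignore
    /// let t = Tensor::from_slice(&[1.0f32, 5.0, 2.0, 4.0], &[4]).unwrap();
    /// assert_eq!(t.argmax().data_f32()[0], 1.0);
    /// ```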
    pub fn argmax(&self) -> Tensor {
        let data = self.data_f32();
        let (idx, _) = data.iter()
            .enumerate()
            // total_cmp imposes a total order, so NaN no longer panics here
            .max_by(|(_, a), (_, b)| a.total_cmp(b))
            .expect("argmax on empty tensor");
        Tensor::from_slice(&[idx as f32], &[]).unwrap()
    }

    /// Argmax along dimension
    pub fn argmax_dim(&self, dim: usize, keepdim: bool) -> Result<Tensor> {
        self.reduce_dim_with_index(dim, keepdim, |slice| {
            slice.iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.total_cmp(b))
                .map(|(i, _)| i as f32)
                .unwrap_or(0.0)
        })
    }

    /// Argmin - flat index of the minimum element, returned as an f32 scalar.
    pub fn argmin(&self) -> Tensor {
        let data = self.data_f32();
        let (idx, _) = data.iter()
            .enumerate()
            .min_by(|(_, a), (_, b)| a.total_cmp(b))
            .expect("argmin on empty tensor");
        Tensor::from_slice(&[idx as f32], &[]).unwrap()
    }

    /// Generic reduction along dimension
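    ///
    /// Index arithmetic (assumes `data_f32()` is contiguous and row-major):
    /// with `stride` the product of the dims after `dim`, the element at
    /// (outer, d, inner) lives at `outer * dim_size * stride + d * stride + inner`,
    /// so each output position gathers `dim_size` strided elements and reduces them.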
    fn reduce_dim<F>(&self, dim: usize, keepdim: bool, reducer: F) -> Result<Tensor>
    where
        F: Fn(&[f32]) -> f32 + Sync,
    {
        if dim >= self.ndim() {
            return Err(GhostError::DimOutOfBounds {
                dim,
                ndim: self.ndim(),
            });
        }

        let dims = self.dims();
        let dim_size = dims[dim];

        // Output shape: drop the reduced dim, or keep it with size 1.
        let mut out_shape: Vec<usize> = dims.iter()
            .enumerate()
            .filter(|&(i, _)| i != dim || keepdim)
            .map(|(i, &d)| if i == dim { 1 } else { d })
            .collect();

        // Reducing the only dim of a 1-D tensor: fall back to shape [1].
        if out_shape.is_empty() {
            out_shape.push(1);
        }

        let data = self.data_f32();
        let out_numel: usize = out_shape.iter().product();

        // Strides for row-major iteration: `stride` spans the dims after
        // `dim`; stepping by it walks along the reduced dimension.
        let stride: usize = dims[dim + 1..].iter().product();
        let outer_stride = dim_size * stride;

        let result: Vec<f32> = (0..out_numel)
            .into_par_iter()
            .map(|out_idx| {
                let outer = out_idx / stride;
                let inner = out_idx % stride;

                // Gather the dim_size elements that collapse into this output.
                let slice: Vec<f32> = (0..dim_size)
                    .map(|d| {
                        let idx = outer * outer_stride + d * stride + inner;
                        data[idx]
                    })
                    .collect();

                reducer(&slice)
            })
            .collect();

        Tensor::from_slice(&result, &out_shape)
    }

    /// Reduction whose reducer maps a slice to an index (e.g. argmax_dim).
    ///
    /// The closure already encodes the winning index as f32, so this can
    /// delegate to the plain value reduction.
    fn reduce_dim_with_index<F>(&self, dim: usize, keepdim: bool, reducer: F) -> Result<Tensor>
    where
        F: Fn(&[f32]) -> f32 + Sync,
    {
        self.reduce_dim(dim, keepdim, reducer)
    }

    /// L2 norm
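    ///
    /// A quick sketch (3-4-5 triangle):
    ///
    /// ```ignore
    /// let t = Tensor::from_slice(&[3.0f32, 4.0], &[2]).unwrap();
    /// assert_eq!(t.norm().data_f32()[0], 5.0); // sqrt(9 + 16)
    /// ```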
    pub fn norm(&self) -> Tensor {
        let data = self.data_f32();
        let sum_sq: f32 = data.par_iter().map(|&x| x * x).sum();
        Tensor::from_slice(&[sum_sq.sqrt()], &[]).unwrap()
    }

    /// L1 norm
    pub fn norm_l1(&self) -> Tensor {
        let data = self.data_f32();
        let sum: f32 = data.par_iter().map(|&x| x.abs()).sum();
        Tensor::from_slice(&[sum], &[]).unwrap()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sum() {
        let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let sum = t.sum();
        assert_eq!(sum.data_f32()[0], 10.0);
    }

    #[test]
    fn test_mean() {
        let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[4]).unwrap();
        let mean = t.mean();
        assert_eq!(mean.data_f32()[0], 2.5);
    }

    #[test]
    fn test_max_min() {
        let t = Tensor::from_slice(&[1.0f32, 5.0, 2.0, 4.0], &[4]).unwrap();
        assert_eq!(t.max().data_f32()[0], 5.0);
        assert_eq!(t.min().data_f32()[0], 1.0);
    }

    #[test]
    fn test_var_std() {
        let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0], &[5]).unwrap();
        let var = t.var(false);
        assert!((var.data_f32()[0] - 2.0).abs() < 0.001);
    }
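
    // Added sketches for the dim reductions and argmax. These assume
    // data_f32() exposes elements in row-major order (as reduce_dim does)
    // and that dims() compares against a usize slice like the code above uses.
    #[test]
    fn test_sum_dim() {
        // 2x3 tensor: rows [1, 2, 3] and [4, 5, 6].
        let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        // Summing over dim 0 collapses the rows into column sums [5, 7, 9].
        let s = t.sum_dim(0, false).unwrap();
        assert_eq!(s.data_f32()[0], 5.0);
        assert_eq!(s.data_f32()[1], 7.0);
        assert_eq!(s.data_f32()[2], 9.0);
        assert_eq!(s.dims(), &[3]);
        // keepdim retains the reduced axis with size 1.
        let k = t.sum_dim(0, true).unwrap();
        assert_eq!(k.dims(), &[1, 3]);
    }

    #[test]
    fn test_argmax() {
        let t = Tensor::from_slice(&[1.0f32, 5.0, 2.0, 4.0], &[4]).unwrap();
        // Flat index of the maximum (5.0) is 1, encoded as f32.
        assert_eq!(t.argmax().data_f32()[0], 1.0);
    }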
}