ghostflow_core/ops/reduction.rs

//! Reduction operations (sum, mean, max, min, etc.)
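//!
//! Scalar reductions (`sum`, `mean`, `max`, ...) return 0-d tensors; the
//! `*_dim` variants reduce along a single dimension. Illustrative sketch
//! (plain text, not a doctest, since the crate's public module paths are
//! not shown in this file):
//!
//! ```text
//! let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2])?;
//! t.sum();               // scalar 10.0
//! t.sum_dim(0, false)?;  // shape [2]: [4.0, 6.0]
//! t.sum_dim(0, true)?;   // shape [1, 2]: [[4.0, 6.0]]
//! ```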

use crate::tensor::Tensor;
use crate::error::{GhostError, Result};
#[cfg(feature = "rayon")]
use rayon::prelude::*;

impl Tensor {
    /// Sum of all elements, returned as a scalar (0-d) tensor.
    pub fn sum(&self) -> Tensor {
        let data = self.data_f32();
        let sum: f32 = data.iter().sum();
        Tensor::from_slice(&[sum], &[]).unwrap()
    }

    /// Sum along a dimension. With `keepdim`, the reduced dimension is
    /// kept with size 1; otherwise it is dropped from the output shape.
    pub fn sum_dim(&self, dim: usize, keepdim: bool) -> Result<Tensor> {
        self.reduce_dim(dim, keepdim, |slice| slice.iter().sum())
    }

    /// Mean of all elements, returned as a scalar tensor.
    pub fn mean(&self) -> Tensor {
        let data = self.data_f32();
        let sum: f32 = data.iter().sum();
        let mean = sum / data.len() as f32;
        Tensor::from_slice(&[mean], &[]).unwrap()
    }

    /// Mean along a dimension.
    pub fn mean_dim(&self, dim: usize, keepdim: bool) -> Result<Tensor> {
        // Check bounds before indexing `dims()`, so an invalid `dim`
        // returns an error instead of panicking.
        if dim >= self.ndim() {
            return Err(GhostError::DimOutOfBounds {
                dim,
                ndim: self.ndim(),
            });
        }
        let dim_size = self.dims()[dim] as f32;
        self.reduce_dim(dim, keepdim, |slice| {
            slice.iter().sum::<f32>() / dim_size
        })
    }

    /// Maximum element as a scalar tensor (`NEG_INFINITY` if the tensor is empty).
    pub fn max(&self) -> Tensor {
        let data = self.data_f32();
        let max = data.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        Tensor::from_slice(&[max], &[]).unwrap()
    }

    /// Maximum along a dimension.
    pub fn max_dim(&self, dim: usize, keepdim: bool) -> Result<Tensor> {
        self.reduce_dim(dim, keepdim, |slice| {
            slice.iter().cloned().fold(f32::NEG_INFINITY, f32::max)
        })
    }

    /// Minimum element as a scalar tensor (`INFINITY` if the tensor is empty).
    pub fn min(&self) -> Tensor {
        let data = self.data_f32();
        let min = data.iter().cloned().fold(f32::INFINITY, f32::min);
        Tensor::from_slice(&[min], &[]).unwrap()
    }

    /// Minimum along a dimension.
    pub fn min_dim(&self, dim: usize, keepdim: bool) -> Result<Tensor> {
        self.reduce_dim(dim, keepdim, |slice| {
            slice.iter().cloned().fold(f32::INFINITY, f32::min)
        })
    }

    /// Product of all elements
    pub fn prod(&self) -> Tensor {
        let data = self.data_f32();
        let prod: f32 = data.iter().product();
        Tensor::from_slice(&[prod], &[]).unwrap()
    }

    /// Variance of all elements. With `unbiased`, divides by `n - 1`
    /// (Bessel's correction) instead of `n`.
    pub fn var(&self, unbiased: bool) -> Tensor {
        let data = self.data_f32();
        let n = data.len() as f32;
        let mean: f32 = data.iter().sum::<f32>() / n;
        let var: f32 = data.iter().map(|&x| (x - mean).powi(2)).sum::<f32>();
        let divisor = if unbiased { n - 1.0 } else { n };
        Tensor::from_slice(&[var / divisor], &[]).unwrap()
    }

    /// Standard deviation (square root of `var`).
    pub fn std(&self, unbiased: bool) -> Tensor {
        let var = self.var(unbiased);
        var.sqrt()
    }

    /// Argmax: flat index of the maximum element, as a scalar tensor.
    /// Panics if the tensor is empty or contains NaN (the comparison fails).
    pub fn argmax(&self) -> Tensor {
        let data = self.data_f32();
        let (idx, _) = data.iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
            .unwrap();
        Tensor::from_slice(&[idx as f32], &[]).unwrap()
    }

    /// Argmax along a dimension; the output holds indices stored as f32.
    pub fn argmax_dim(&self, dim: usize, keepdim: bool) -> Result<Tensor> {
        self.reduce_dim_with_index(dim, keepdim, |slice| {
            slice.iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
                .map(|(i, _)| i as f32)
                .unwrap_or(0.0)
        })
    }

    /// Argmin: flat index of the minimum element, as a scalar tensor.
    /// Panics if the tensor is empty or contains NaN.
    pub fn argmin(&self) -> Tensor {
        let data = self.data_f32();
        let (idx, _) = data.iter()
            .enumerate()
            .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
            .unwrap();
        Tensor::from_slice(&[idx as f32], &[]).unwrap()
    }

    /// Generic reduction along a dimension: `reducer` collapses each
    /// slice taken along `dim` to a single value.
    fn reduce_dim<F>(&self, dim: usize, keepdim: bool, reducer: F) -> Result<Tensor>
    where
        F: Fn(&[f32]) -> f32 + Sync,
    {
        if dim >= self.ndim() {
            return Err(GhostError::DimOutOfBounds {
                dim,
                ndim: self.ndim(),
            });
        }

        let dims = self.dims();
        let dim_size = dims[dim];

        // Output shape: drop `dim`, or keep it with size 1 when `keepdim`.
        let mut out_shape: Vec<usize> = dims.iter()
            .enumerate()
            .filter(|&(i, _)| i != dim || keepdim)
            .map(|(i, &d)| if i == dim { 1 } else { d })
            .collect();

        // Reducing the only dimension of a 1-d tensor yields shape [1].
        if out_shape.is_empty() {
            out_shape.push(1);
        }

        let data = self.data_f32();
        let out_numel: usize = out_shape.iter().product();

        // Row-major strides: elements along `dim` are `stride` apart, and
        // consecutive outer blocks are `dim_size * stride` apart. E.g. for
        // dims [2, 3, 4] and dim = 1: stride = 4, outer_stride = 12.
        let stride: usize = dims[dim + 1..].iter().product();
        let outer_stride = dim_size * stride;

        // Each output position (outer, inner) gathers the `dim_size`
        // elements of its slice and reduces them to one value.
        let reduce_one = |out_idx: usize| {
            let outer = out_idx / stride;
            let inner = out_idx % stride;

            let slice: Vec<f32> = (0..dim_size)
                .map(|d| data[outer * outer_stride + d * stride + inner])
                .collect();

            reducer(&slice)
        };

        // With the `rayon` feature, reduce output positions in parallel;
        // this is why `F` carries the `Sync` bound.
        #[cfg(feature = "rayon")]
        let result: Vec<f32> = (0..out_numel).into_par_iter().map(reduce_one).collect();
        #[cfg(not(feature = "rayon"))]
        let result: Vec<f32> = (0..out_numel).map(reduce_one).collect();

        Tensor::from_slice(&result, &out_shape)
    }

    /// Reduction whose reducer returns an index (as f32) rather than a
    /// value. The index is produced by the reducer itself, so this can
    /// simply delegate to `reduce_dim`.
    fn reduce_dim_with_index<F>(&self, dim: usize, keepdim: bool, reducer: F) -> Result<Tensor>
    where
        F: Fn(&[f32]) -> f32 + Sync,
    {
        self.reduce_dim(dim, keepdim, reducer)
    }

    /// L2 norm: square root of the sum of squared elements.
    pub fn norm(&self) -> Tensor {
        let data = self.data_f32();
        let sum_sq: f32 = data.iter().map(|&x| x * x).sum();
        Tensor::from_slice(&[sum_sq.sqrt()], &[]).unwrap()
    }

    /// L1 norm: sum of absolute values.
    pub fn norm_l1(&self) -> Tensor {
        let data = self.data_f32();
        let sum: f32 = data.iter().map(|&x| x.abs()).sum();
        Tensor::from_slice(&[sum], &[]).unwrap()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sum() {
        let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let sum = t.sum();
        assert_eq!(sum.data_f32()[0], 10.0);
    }

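    // sum_dim(0, ..) on a 2x2 sums down the columns; keepdim=false
    // drops the reduced dimension, leaving shape [2].
    #[test]
    fn test_sum_dim() {
        let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let s = t.sum_dim(0, false).unwrap();
        assert_eq!(s.data_f32()[0], 4.0); // 1 + 3
        assert_eq!(s.data_f32()[1], 6.0); // 2 + 4
    }
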
    #[test]
    fn test_mean() {
        let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[4]).unwrap();
        let mean = t.mean();
        assert_eq!(mean.data_f32()[0], 2.5);
    }

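    // mean_dim with keepdim=true keeps the reduced dim with size 1, so
    // the row means of a 2x2 come back with shape [2, 1]. Also checks
    // that an out-of-bounds dim is reported as an error.
    #[test]
    fn test_mean_dim() {
        let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let m = t.mean_dim(1, true).unwrap();
        assert_eq!(m.data_f32()[0], 1.5); // mean of row [1, 2]
        assert_eq!(m.data_f32()[1], 3.5); // mean of row [3, 4]
        assert!(t.mean_dim(2, false).is_err());
    }
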
    #[test]
    fn test_max_min() {
        let t = Tensor::from_slice(&[1.0f32, 5.0, 2.0, 4.0], &[4]).unwrap();
        assert_eq!(t.max().data_f32()[0], 5.0);
        assert_eq!(t.min().data_f32()[0], 1.0);
    }

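    // max_dim/min_dim over dim 0 of a 2x2 reduce down the columns:
    // rows are [1, 5] and [2, 4].
    #[test]
    fn test_max_min_dim() {
        let t = Tensor::from_slice(&[1.0f32, 5.0, 2.0, 4.0], &[2, 2]).unwrap();
        let mx = t.max_dim(0, false).unwrap();
        assert_eq!(mx.data_f32()[0], 2.0); // max(1, 2)
        assert_eq!(mx.data_f32()[1], 5.0); // max(5, 4)
        let mn = t.min_dim(0, false).unwrap();
        assert_eq!(mn.data_f32()[0], 1.0); // min(1, 2)
        assert_eq!(mn.data_f32()[1], 4.0); // min(5, 4)
    }
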
    #[test]
    fn test_var_std() {
        let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0], &[5]).unwrap();
        let var = t.var(false);
        assert!((var.data_f32()[0] - 2.0).abs() < 0.001);
        // Unbiased variance divides by n - 1: 10.0 / 4 = 2.5.
        assert!((t.var(true).data_f32()[0] - 2.5).abs() < 0.001);
        // std is the square root of the (here biased) variance.
        assert!((t.std(false).data_f32()[0] - 2.0f32.sqrt()).abs() < 0.001);
    }
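
    // argmax/argmin return flat indices stored as f32 scalars.
    #[test]
    fn test_argmax_argmin() {
        let t = Tensor::from_slice(&[1.0f32, 5.0, 2.0, 4.0], &[4]).unwrap();
        assert_eq!(t.argmax().data_f32()[0], 1.0);
        assert_eq!(t.argmin().data_f32()[0], 0.0);
    }

    // norm is the Euclidean length, norm_l1 the sum of absolute values,
    // prod the product of all elements; [3, -4] gives the 3-4-5 triangle.
    #[test]
    fn test_norms_and_prod() {
        let t = Tensor::from_slice(&[3.0f32, -4.0], &[2]).unwrap();
        assert_eq!(t.norm().data_f32()[0], 5.0);    // sqrt(9 + 16)
        assert_eq!(t.norm_l1().data_f32()[0], 7.0); // |3| + |-4|
        let u = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[4]).unwrap();
        assert_eq!(u.prod().data_f32()[0], 24.0);
    }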
}