use crate::tensor::Tensor;
use crate::error::{GhostError, Result};
use rayon::prelude::*;

impl Tensor {
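    /// Sum of every element, returned as a rank-0 (scalar) tensor.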
    pub fn sum(&self) -> Tensor {
        let data = self.data_f32();
        let sum: f32 = data.par_iter().sum();
        Tensor::from_slice(&[sum], &[]).unwrap()
    }

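    /// Sum along `dim`. With `keepdim` the reduced axis is kept with
    /// size 1; otherwise it is dropped from the output shape.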
    pub fn sum_dim(&self, dim: usize, keepdim: bool) -> Result<Tensor> {
        self.reduce_dim(dim, keepdim, |slice| slice.iter().sum())
    }

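    /// Arithmetic mean of every element, as a rank-0 tensor.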
    pub fn mean(&self) -> Tensor {
        let data = self.data_f32();
        let sum: f32 = data.par_iter().sum();
        let mean = sum / data.len() as f32;
        Tensor::from_slice(&[mean], &[]).unwrap()
    }

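    /// Mean along `dim`: the per-lane sum divided by the size of `dim`.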
    pub fn mean_dim(&self, dim: usize, keepdim: bool) -> Result<Tensor> {
        // Validate before indexing so an out-of-range `dim` returns an
        // error instead of panicking on the `dims()[dim]` access.
        if dim >= self.ndim() {
            return Err(GhostError::DimOutOfBounds {
                dim,
                ndim: self.ndim(),
            });
        }
        let dim_size = self.dims()[dim] as f32;
        self.reduce_dim(dim, keepdim, |slice| slice.iter().sum::<f32>() / dim_size)
    }

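    /// Largest element, as a rank-0 tensor. The reduction identity is
    /// `f32::NEG_INFINITY`, which is also what an empty tensor yields.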
    pub fn max(&self) -> Tensor {
        let data = self.data_f32();
        let max = data.par_iter().cloned().reduce(|| f32::NEG_INFINITY, f32::max);
        Tensor::from_slice(&[max], &[]).unwrap()
    }

    pub fn max_dim(&self, dim: usize, keepdim: bool) -> Result<Tensor> {
        self.reduce_dim(dim, keepdim, |slice| {
            slice.iter().cloned().fold(f32::NEG_INFINITY, f32::max)
        })
    }

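    /// Smallest element, as a rank-0 tensor. The reduction identity is
    /// `f32::INFINITY`, which is also what an empty tensor yields.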
    pub fn min(&self) -> Tensor {
        let data = self.data_f32();
        let min = data.par_iter().cloned().reduce(|| f32::INFINITY, f32::min);
        Tensor::from_slice(&[min], &[]).unwrap()
    }

    pub fn min_dim(&self, dim: usize, keepdim: bool) -> Result<Tensor> {
        self.reduce_dim(dim, keepdim, |slice| {
            slice.iter().cloned().fold(f32::INFINITY, f32::min)
        })
    }

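    /// Product of every element, as a rank-0 tensor.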
    pub fn prod(&self) -> Tensor {
        let data = self.data_f32();
        let prod: f32 = data.par_iter().product();
        Tensor::from_slice(&[prod], &[]).unwrap()
    }

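    /// Variance of all elements, computed in two passes (mean, then
    /// squared deviations). `unbiased == true` divides by `n - 1`
    /// (Bessel's correction); `false` divides by `n`.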
    pub fn var(&self, unbiased: bool) -> Tensor {
        let data = self.data_f32();
        let n = data.len() as f32;
        let mean: f32 = data.par_iter().sum::<f32>() / n;
        let var: f32 = data.par_iter().map(|&x| (x - mean).powi(2)).sum::<f32>();
        let divisor = if unbiased { n - 1.0 } else { n };
        Tensor::from_slice(&[var / divisor], &[]).unwrap()
    }

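    /// Standard deviation: the square root of `var(unbiased)`.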
    pub fn std(&self, unbiased: bool) -> Tensor {
        let var = self.var(unbiased);
        var.sqrt()
    }

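    /// Index of the largest element (flattened), encoded as f32 in a
    /// rank-0 tensor. Panics on an empty tensor.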
    pub fn argmax(&self) -> Tensor {
        let data = self.data_f32();
        // `total_cmp` imposes a total order on f32 (IEEE totalOrder), so
        // the comparison never panics on NaN the way
        // `partial_cmp(..).unwrap()` would.
        let (idx, _) = data
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.total_cmp(b))
            .unwrap();
        Tensor::from_slice(&[idx as f32], &[]).unwrap()
    }

    pub fn argmax_dim(&self, dim: usize, keepdim: bool) -> Result<Tensor> {
        self.reduce_dim_with_index(dim, keepdim, |slice| {
            slice
                .iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.total_cmp(b))
                .map(|(i, _)| i as f32)
                .unwrap_or(0.0)
        })
    }

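    /// Index of the smallest element (flattened), encoded as f32 in a
    /// rank-0 tensor. Panics on an empty tensor.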
    pub fn argmin(&self) -> Tensor {
        let data = self.data_f32();
        // Same NaN-safe total ordering as `argmax`.
        let (idx, _) = data
            .iter()
            .enumerate()
            .min_by(|(_, a), (_, b)| a.total_cmp(b))
            .unwrap();
        Tensor::from_slice(&[idx as f32], &[]).unwrap()
    }

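    /// Shared dim-wise reduction. For contiguous row-major data, the
    /// element at logical position (outer, d, inner) lives at flat index
    /// `outer * dim_size * stride + d * stride + inner`, where `stride`
    /// is the product of the dimensions after `dim`. Each output element
    /// gathers its lane along `dim` and applies `reducer` to it.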
    fn reduce_dim<F>(&self, dim: usize, keepdim: bool, reducer: F) -> Result<Tensor>
    where
        F: Fn(&[f32]) -> f32 + Sync,
    {
        if dim >= self.ndim() {
            return Err(GhostError::DimOutOfBounds {
                dim,
                ndim: self.ndim(),
            });
        }

        let dims = self.dims();
        let dim_size = dims[dim];

        // Keep the reduced axis as size 1 when `keepdim` is set,
        // otherwise drop it from the output shape.
        let mut out_shape: Vec<usize> = dims
            .iter()
            .enumerate()
            .filter(|&(i, _)| i != dim || keepdim)
            .map(|(i, &d)| if i == dim { 1 } else { d })
            .collect();

        // Reducing the only axis of a 1-D tensor leaves no dims; use [1].
        if out_shape.is_empty() {
            out_shape.push(1);
        }

        let data = self.data_f32();
        let out_numel: usize = out_shape.iter().product();

        let stride: usize = dims[dim + 1..].iter().product();
        let outer_stride = dim_size * stride;

        let result: Vec<f32> = (0..out_numel)
            .into_par_iter()
            .map(|out_idx| {
                // Split the flat output index into the coordinates before
                // and after the reduced axis.
                let outer = out_idx / stride;
                let inner = out_idx % stride;

                // Gather the lane along `dim` into a scratch buffer.
                let slice: Vec<f32> = (0..dim_size)
                    .map(|d| {
                        let idx = outer * outer_stride + d * stride + inner;
                        data[idx]
                    })
                    .collect();

                reducer(&slice)
            })
            .collect();

        Tensor::from_slice(&result, &out_shape)
    }

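    /// Index-returning variant used by `argmax_dim`: the reducer encodes
    /// the winning position as f32, so plain `reduce_dim` does the work.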
    fn reduce_dim_with_index<F>(&self, dim: usize, keepdim: bool, reducer: F) -> Result<Tensor>
    where
        F: Fn(&[f32]) -> f32 + Sync,
    {
        self.reduce_dim(dim, keepdim, reducer)
    }

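    /// L2 (Euclidean / Frobenius) norm: sqrt of the sum of squares.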
    pub fn norm(&self) -> Tensor {
        let data = self.data_f32();
        let sum_sq: f32 = data.par_iter().map(|&x| x * x).sum();
        Tensor::from_slice(&[sum_sq.sqrt()], &[]).unwrap()
    }

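    /// L1 norm: sum of absolute values.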
    pub fn norm_l1(&self) -> Tensor {
        let data = self.data_f32();
        let sum: f32 = data.par_iter().map(|&x| x.abs()).sum();
        Tensor::from_slice(&[sum], &[]).unwrap()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sum() {
        let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let sum = t.sum();
        assert_eq!(sum.data_f32()[0], 10.0);
    }

    #[test]
    fn test_mean() {
        let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[4]).unwrap();
        let mean = t.mean();
        assert_eq!(mean.data_f32()[0], 2.5);
    }

    #[test]
    fn test_max_min() {
        let t = Tensor::from_slice(&[1.0f32, 5.0, 2.0, 4.0], &[4]).unwrap();
        assert_eq!(t.max().data_f32()[0], 5.0);
        assert_eq!(t.min().data_f32()[0], 1.0);
    }

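    // Additional coverage sketches for the dim-wise and norm paths,
    // assuming the same row-major contiguous layout exercised above.
    #[test]
    fn test_sum_dim() {
        let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        // Reducing dim 0 sums down the columns: [1 + 3, 2 + 4].
        let s = t.sum_dim(0, false).unwrap();
        assert_eq!(s.data_f32()[0], 4.0);
        assert_eq!(s.data_f32()[1], 6.0);
        // keepdim retains the reduced axis with size 1.
        let k = t.sum_dim(0, true).unwrap();
        assert_eq!(k.dims()[0], 1);
        assert_eq!(k.dims()[1], 2);
    }

    #[test]
    fn test_argmax_argmin() {
        let t = Tensor::from_slice(&[1.0f32, 5.0, 2.0, 4.0], &[4]).unwrap();
        assert_eq!(t.argmax().data_f32()[0], 1.0);
        assert_eq!(t.argmin().data_f32()[0], 0.0);
    }

    #[test]
    fn test_norms() {
        let t = Tensor::from_slice(&[3.0f32, -4.0], &[2]).unwrap();
        // L2: sqrt(9 + 16) = 5; L1: |3| + |-4| = 7.
        assert_eq!(t.norm().data_f32()[0], 5.0);
        assert_eq!(t.norm_l1().data_f32()[0], 7.0);
    }

    #[test]
    fn test_dim_out_of_bounds() {
        let t = Tensor::from_slice(&[1.0f32, 2.0], &[2]).unwrap();
        assert!(t.sum_dim(1, false).is_err());
        assert!(t.mean_dim(1, false).is_err());
    }
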
    #[test]
    fn test_var_std() {
        let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0], &[5]).unwrap();
        let var = t.var(false);
        assert!((var.data_f32()[0] - 2.0).abs() < 0.001);
        // std is the square root of the population variance: sqrt(2).
        let std = t.std(false);
        assert!((std.data_f32()[0] - 2.0f32.sqrt()).abs() < 0.001);
    }
}
227}