//! Reduction ops: full-tensor and per-dimension sums, means, extrema, and norms.

use crate::tensor::Tensor;
use crate::error::{GhostError, Result};
#[cfg(feature = "rayon")]
use rayon::prelude::*;

impl Tensor {
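    /// Sum of all elements, returned as a zero-dimensional (scalar) tensor.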
    pub fn sum(&self) -> Tensor {
        let data = self.data_f32();
        let sum: f32 = data.iter().sum();
        Tensor::from_slice(&[sum], &[]).unwrap()
    }

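    /// Sum along `dim`; `keepdim` keeps the reduced axis with size 1.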
    pub fn sum_dim(&self, dim: usize, keepdim: bool) -> Result<Tensor> {
        self.reduce_dim(dim, keepdim, |slice| slice.iter().sum())
    }

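    /// Arithmetic mean of all elements, returned as a scalar tensor.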
    pub fn mean(&self) -> Tensor {
        let data = self.data_f32();
        let sum: f32 = data.iter().sum();
        let mean = sum / data.len() as f32;
        Tensor::from_slice(&[mean], &[]).unwrap()
    }

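    /// Mean along `dim`; `keepdim` keeps the reduced axis with size 1.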
    pub fn mean_dim(&self, dim: usize, keepdim: bool) -> Result<Tensor> {
        // Validate `dim` before indexing into dims(), so an out-of-range dim
        // returns an error instead of panicking. reduce_dim repeats the check.
        if dim >= self.ndim() {
            return Err(GhostError::DimOutOfBounds { dim, ndim: self.ndim() });
        }
        let dim_size = self.dims()[dim] as f32;
        self.reduce_dim(dim, keepdim, |slice| {
            slice.iter().sum::<f32>() / dim_size
        })
    }

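    /// Largest element; folds from NEG_INFINITY, so an empty tensor yields -inf.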
    pub fn max(&self) -> Tensor {
        let data = self.data_f32();
        let max = data.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        Tensor::from_slice(&[max], &[]).unwrap()
    }

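    /// Maximum along `dim`; `keepdim` keeps the reduced axis with size 1.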
    pub fn max_dim(&self, dim: usize, keepdim: bool) -> Result<Tensor> {
        self.reduce_dim(dim, keepdim, |slice| {
            slice.iter().cloned().fold(f32::NEG_INFINITY, f32::max)
        })
    }

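    /// Smallest element; folds from INFINITY, so an empty tensor yields +inf.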
    pub fn min(&self) -> Tensor {
        let data = self.data_f32();
        let min = data.iter().cloned().fold(f32::INFINITY, f32::min);
        Tensor::from_slice(&[min], &[]).unwrap()
    }

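    /// Minimum along `dim`; `keepdim` keeps the reduced axis with size 1.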
    pub fn min_dim(&self, dim: usize, keepdim: bool) -> Result<Tensor> {
        self.reduce_dim(dim, keepdim, |slice| {
            slice.iter().cloned().fold(f32::INFINITY, f32::min)
        })
    }

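    /// Product of all elements, returned as a scalar tensor.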
    pub fn prod(&self) -> Tensor {
        let data = self.data_f32();
        let prod: f32 = data.iter().product();
        Tensor::from_slice(&[prod], &[]).unwrap()
    }

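    /// Variance of all elements; `unbiased` divides by n - 1 (Bessel's correction) instead of n.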
    pub fn var(&self, unbiased: bool) -> Tensor {
        let data = self.data_f32();
        let n = data.len() as f32;
        let mean: f32 = data.iter().sum::<f32>() / n;
        let var: f32 = data.iter().map(|&x| (x - mean).powi(2)).sum::<f32>();
        let divisor = if unbiased { n - 1.0 } else { n };
        Tensor::from_slice(&[var / divisor], &[]).unwrap()
    }

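    /// Standard deviation: the square root of `var(unbiased)`.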
    pub fn std(&self, unbiased: bool) -> Tensor {
        let var = self.var(unbiased);
        var.sqrt()
    }

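    /// Flat index of the largest element, returned as an f32 scalar tensor.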
    pub fn argmax(&self) -> Tensor {
        let data = self.data_f32();
        // total_cmp imposes a total order on f32, so this cannot panic on NaN
        // input the way partial_cmp(..).unwrap() does.
        let (idx, _) = data.iter()
            .enumerate()
            .max_by(|&(_, a), &(_, b)| a.total_cmp(b))
            .unwrap();
        Tensor::from_slice(&[idx as f32], &[]).unwrap()
    }

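    /// Per-slice index of the largest element along `dim`.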
    pub fn argmax_dim(&self, dim: usize, keepdim: bool) -> Result<Tensor> {
        self.reduce_dim_with_index(dim, keepdim, |slice| {
            slice.iter()
                .enumerate()
                .max_by(|&(_, a), &(_, b)| a.total_cmp(b))
                .map(|(i, _)| i as f32)
                .unwrap_or(0.0)
        })
    }

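    /// Flat index of the smallest element, returned as an f32 scalar tensor.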
    pub fn argmin(&self) -> Tensor {
        let data = self.data_f32();
        let (idx, _) = data.iter()
            .enumerate()
            .min_by(|&(_, a), &(_, b)| a.total_cmp(b))
            .unwrap();
        Tensor::from_slice(&[idx as f32], &[]).unwrap()
    }

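    /// Shared reduction kernel: for each output position, gathers the
    /// `dim_size` elements along `dim` into a contiguous buffer and applies
    /// `reducer`. Assumes contiguous row-major storage.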
    fn reduce_dim<F>(&self, dim: usize, keepdim: bool, reducer: F) -> Result<Tensor>
    where
        F: Fn(&[f32]) -> f32 + Sync,
    {
        if dim >= self.ndim() {
            return Err(GhostError::DimOutOfBounds {
                dim,
                ndim: self.ndim(),
            });
        }

        let dims = self.dims();
        let dim_size = dims[dim];

        // Output shape: drop `dim`, or keep it as size 1 when keepdim is set.
        let mut out_shape: Vec<usize> = dims.iter()
            .enumerate()
            .filter(|&(i, _)| i != dim || keepdim)
            .map(|(i, &d)| if i == dim { 1 } else { d })
            .collect();

        // Reducing a 1-D tensor yields a scalar; represent it as shape [1].
        if out_shape.is_empty() {
            out_shape.push(1);
        }

        let data = self.data_f32();
        let out_numel: usize = out_shape.iter().product();

        // In row-major layout, consecutive elements along `dim` are `stride`
        // apart and consecutive outer rows are `outer_stride` apart.
        let stride: usize = dims[dim + 1..].iter().product();
        let outer_stride = dim_size * stride;

        let reduce_one = |out_idx: usize| {
            let outer = out_idx / stride;
            let inner = out_idx % stride;

            // Gather the strided elements along `dim` into a contiguous buffer.
            let slice: Vec<f32> = (0..dim_size)
                .map(|d| data[outer * outer_stride + d * stride + inner])
                .collect();

            reducer(&slice)
        };

        // The `Sync` bound on `reducer` lets the outer loop run in parallel
        // when the "rayon" feature is enabled.
        #[cfg(feature = "rayon")]
        let result: Vec<f32> = (0..out_numel).into_par_iter().map(reduce_one).collect();
        #[cfg(not(feature = "rayon"))]
        let result: Vec<f32> = (0..out_numel).map(reduce_one).collect();

        Tensor::from_slice(&result, &out_shape)
    }

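    /// Reduction whose result is an index rather than a value (argmax/argmin
    /// along a dimension); the gather logic is identical, so it delegates to
    /// `reduce_dim`.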
    fn reduce_dim_with_index<F>(&self, dim: usize, keepdim: bool, reducer: F) -> Result<Tensor>
    where
        F: Fn(&[f32]) -> f32 + Sync,
    {
        self.reduce_dim(dim, keepdim, reducer)
    }

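    /// L2 (Euclidean) norm: square root of the sum of squared elements.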
    pub fn norm(&self) -> Tensor {
        let data = self.data_f32();
        let sum_sq: f32 = data.iter().map(|&x| x * x).sum();
        Tensor::from_slice(&[sum_sq.sqrt()], &[]).unwrap()
    }

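    /// L1 norm: sum of the absolute values of the elements.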
    pub fn norm_l1(&self) -> Tensor {
        let data = self.data_f32();
        let sum: f32 = data.iter().map(|&x| x.abs()).sum();
        Tensor::from_slice(&[sum], &[]).unwrap()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sum() {
        let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let sum = t.sum();
        assert_eq!(sum.data_f32()[0], 10.0);
    }

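    // Exercises the reduce_dim gather path: column sums (dim 0) and row sums
    // (dim 1) of a 2x2 tensor, plus the out-of-bounds error case.
    #[test]
    fn test_sum_dim() {
        let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let cols = t.sum_dim(0, false).unwrap();
        assert_eq!(cols.data_f32()[0], 4.0);
        assert_eq!(cols.data_f32()[1], 6.0);
        let rows = t.sum_dim(1, false).unwrap();
        assert_eq!(rows.data_f32()[0], 3.0);
        assert_eq!(rows.data_f32()[1], 7.0);
        assert!(t.sum_dim(2, false).is_err());
    }
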
    #[test]
    fn test_mean() {
        let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[4]).unwrap();
        let mean = t.mean();
        assert_eq!(mean.data_f32()[0], 2.5);
    }

    #[test]
    fn test_max_min() {
        let t = Tensor::from_slice(&[1.0f32, 5.0, 2.0, 4.0], &[4]).unwrap();
        assert_eq!(t.max().data_f32()[0], 5.0);
        assert_eq!(t.min().data_f32()[0], 1.0);
    }

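    // argmax/argmin report flat indices as f32 scalars; ties go to the first
    // occurrence for argmin.
    #[test]
    fn test_argmax_argmin() {
        let t = Tensor::from_slice(&[3.0f32, 1.0, 4.0, 1.0, 5.0], &[5]).unwrap();
        assert_eq!(t.argmax().data_f32()[0], 4.0);
        assert_eq!(t.argmin().data_f32()[0], 1.0);
    }
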
    #[test]
    fn test_var_std() {
        let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0], &[5]).unwrap();
        let var = t.var(false);
        assert!((var.data_f32()[0] - 2.0).abs() < 0.001);
    }
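
    // A sketch of the keepdim contract, assuming dims() returns a slice
    // comparable with a &[usize; N] literal.
    #[test]
    fn test_mean_dim_keepdim() {
        let t = Tensor::from_slice(&[2.0f32, 4.0, 6.0, 8.0], &[2, 2]).unwrap();
        let m = t.mean_dim(1, true).unwrap();
        assert_eq!(m.dims(), &[2, 1]);
        assert_eq!(m.data_f32()[0], 3.0);
        assert_eq!(m.data_f32()[1], 7.0);
    }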
}