1use crate::ops::arithmetic::ArithmeticOps;
7use crate::ops::shape::ShapeOps;
8use crate::tensor::Tensor;
9use anyhow::{Result, anyhow};
10
11pub trait ReductionOps {
13 fn sum_all(&self) -> Result<Tensor>;
15
16 fn sum_dims(&self, dims: &[usize], keep_dim: bool) -> Result<Tensor>;
18
19 fn mean_all(&self) -> Result<Tensor>;
21
22 fn mean_dims(&self, dims: &[usize], keep_dim: bool) -> Result<Tensor>;
24
25 fn max_all(&self) -> Result<Tensor>;
27
28 fn max_dims(&self, dims: &[usize], keep_dim: bool) -> Result<Tensor>;
30
31 fn min_all(&self) -> Result<Tensor>;
33
34 fn min_dims(&self, dims: &[usize], keep_dim: bool) -> Result<Tensor>;
36
37 fn prod_all(&self) -> Result<Tensor>;
39
40 fn std_all(&self) -> Result<Tensor>;
42
43 fn var_all(&self) -> Result<Tensor>;
45
46 fn norm(&self) -> Result<Tensor>;
48
49 fn norm_p(&self, p: f32) -> Result<Tensor>;
51}
52
53impl ReductionOps for Tensor {
54 fn sum_all(&self) -> Result<Tensor> {
55 let result_candle = self.candle_tensor().sum_all()?;
56
57 let reshaped = if result_candle.dims().is_empty() {
59 result_candle.reshape(&[1])?
60 } else {
61 result_candle
62 };
63
64 Ok(Tensor::from_candle(reshaped, self.dtype(), self.layout()))
65 }
66
67 fn sum_dims(&self, dims: &[usize], keep_dim: bool) -> Result<Tensor> {
68 let shape = self.shape();
69
70 for &dim in dims {
72 if dim >= shape.len() {
73 return Err(anyhow!(
74 "Dimension {} is out of bounds for tensor with {} dimensions",
75 dim,
76 shape.len()
77 ));
78 }
79 }
80
81 let result_candle = if keep_dim {
82 self.candle_tensor().sum_keepdim(dims)?
83 } else {
84 self.candle_tensor().sum(dims)?
85 };
86
87 Ok(Tensor::from_candle(
88 result_candle,
89 self.dtype(),
90 self.layout(),
91 ))
92 }
93
94 fn mean_all(&self) -> Result<Tensor> {
95 let sum = self.sum_all()?;
96 let num_elements = self.numel() as f32;
97 sum.div_scalar(num_elements)
98 }
99
100 fn mean_dims(&self, dims: &[usize], keep_dim: bool) -> Result<Tensor> {
101 let sum = self.sum_dims(dims, keep_dim)?;
102
103 let shape = self.shape();
105 let reduction_size: usize = dims.iter().map(|&dim| shape[dim]).product();
106
107 sum.div_scalar(reduction_size as f32)
108 }
109
110 fn max_all(&self) -> Result<Tensor> {
111 let flattened = self.flatten()?;
112 let result_candle = flattened.candle_tensor().max(0)?;
113
114 let reshaped = if result_candle.dims().is_empty() {
116 result_candle.reshape(&[1])?
117 } else {
118 result_candle
119 };
120
121 Ok(Tensor::from_candle(reshaped, self.dtype(), self.layout()))
122 }
123
124 fn max_dims(&self, dims: &[usize], keep_dim: bool) -> Result<Tensor> {
125 let shape = self.shape();
126
127 for &dim in dims {
129 if dim >= shape.len() {
130 return Err(anyhow!(
131 "Dimension {} is out of bounds for tensor with {} dimensions",
132 dim,
133 shape.len()
134 ));
135 }
136 }
137
138 let mut result = self.clone();
140 let mut sorted_dims = dims.to_vec();
141 sorted_dims.sort_unstable();
142 sorted_dims.reverse(); for &dim in &sorted_dims {
145 let result_candle = if keep_dim {
146 result.candle_tensor().max_keepdim(dim)?
147 } else {
148 result.candle_tensor().max(dim)?
149 };
150 result = Tensor::from_candle(result_candle, result.dtype(), result.layout());
151 }
152
153 Ok(result)
154 }
155
156 fn min_all(&self) -> Result<Tensor> {
157 let flattened = self.flatten()?;
158 let result_candle = flattened.candle_tensor().min(0)?;
159
160 let reshaped = if result_candle.dims().is_empty() {
162 result_candle.reshape(&[1])?
163 } else {
164 result_candle
165 };
166
167 Ok(Tensor::from_candle(reshaped, self.dtype(), self.layout()))
168 }
169
170 fn min_dims(&self, dims: &[usize], keep_dim: bool) -> Result<Tensor> {
171 let shape = self.shape();
172
173 for &dim in dims {
175 if dim >= shape.len() {
176 return Err(anyhow!(
177 "Dimension {} is out of bounds for tensor with {} dimensions",
178 dim,
179 shape.len()
180 ));
181 }
182 }
183
184 let mut result = self.clone();
186 let mut sorted_dims = dims.to_vec();
187 sorted_dims.sort_unstable();
188 sorted_dims.reverse(); for &dim in &sorted_dims {
191 let result_candle = if keep_dim {
192 result.candle_tensor().min_keepdim(dim)?
193 } else {
194 result.candle_tensor().min(dim)?
195 };
196 result = Tensor::from_candle(result_candle, result.dtype(), result.layout());
197 }
198
199 Ok(result)
200 }
201
202 fn prod_all(&self) -> Result<Tensor> {
203 let data = self.to_vec()?;
205 let product = data.iter().fold(1.0, |acc, &x| acc * x);
206
207 Ok(Tensor::from_data(
208 vec![product],
209 vec![1],
210 self.dtype(),
211 self.layout(),
212 )?)
213 }
214
215 fn std_all(&self) -> Result<Tensor> {
216 let variance = self.var_all()?;
217 variance.sqrt()
218 }
219
220 fn var_all(&self) -> Result<Tensor> {
221 let mean = self.mean_all()?;
222 let diff = self.sub(&mean)?;
223 let squared_diff = diff.mul(&diff)?;
224 squared_diff.mean_all()
225 }
226
227 fn norm(&self) -> Result<Tensor> {
228 self.norm_p(2.0)
229 }
230
231 fn norm_p(&self, p: f32) -> Result<Tensor> {
232 if p <= 0.0 {
233 return Err(anyhow!("Norm p must be positive, got {}", p));
234 }
235
236 if p == 1.0 {
237 let abs_values = self.abs()?;
239 abs_values.sum_all()
240 } else if p == 2.0 {
241 let squared = self.mul(self)?;
243 let sum_squared = squared.sum_all()?;
244 sum_squared.sqrt()
245 } else if p.is_infinite() {
246 let abs_values = self.abs()?;
248 abs_values.max_all()
249 } else {
250 let abs_values = self.abs()?;
252 let powered = abs_values.pow(p)?;
253 let sum_powered = powered.sum_all()?;
254 sum_powered.pow(1.0 / p)
255 }
256 }
257}
258
259impl Tensor {
261 pub fn sum_dim(&self, dim: usize, keep_dim: bool) -> Result<Tensor> {
263 self.sum_dims(&[dim], keep_dim)
264 }
265
266 pub fn mean_dim(&self, dim: usize, keep_dim: bool) -> Result<Tensor> {
268 self.mean_dims(&[dim], keep_dim)
269 }
270
271 pub fn max_dim(&self, dim: usize, keep_dim: bool) -> Result<Tensor> {
273 self.max_dims(&[dim], keep_dim)
274 }
275
276 pub fn min_dim(&self, dim: usize, keep_dim: bool) -> Result<Tensor> {
278 self.min_dims(&[dim], keep_dim)
279 }
280
281 pub fn argmax(&self, dim: usize, keep_dim: bool) -> Result<Tensor> {
283 let shape = self.shape();
284
285 if dim >= shape.len() {
286 return Err(anyhow!(
287 "Dimension {} is out of bounds for tensor with {} dimensions",
288 dim,
289 shape.len()
290 ));
291 }
292
293 let result_candle = if keep_dim {
294 self.candle_tensor().argmax_keepdim(dim)?
295 } else {
296 self.candle_tensor().argmax(dim)?
297 };
298
299 Ok(Tensor::from_candle(
300 result_candle,
301 crate::types::DataType::U32, self.layout(),
303 ))
304 }
305
306 pub fn argmin(&self, dim: usize, keep_dim: bool) -> Result<Tensor> {
308 let shape = self.shape();
309
310 if dim >= shape.len() {
311 return Err(anyhow!(
312 "Dimension {} is out of bounds for tensor with {} dimensions",
313 dim,
314 shape.len()
315 ));
316 }
317
318 let result_candle = if keep_dim {
319 self.candle_tensor().argmin_keepdim(dim)?
320 } else {
321 self.candle_tensor().argmin(dim)?
322 };
323
324 Ok(Tensor::from_candle(
325 result_candle,
326 crate::types::DataType::U32, self.layout(),
328 ))
329 }
330
331 pub fn count_nonzero(&self) -> Result<usize> {
333 let data = self.to_vec()?;
334 Ok(data.iter().filter(|&&x| x != 0.0).count())
335 }
336
337 pub fn count_nonzero_dim(&self, dim: usize) -> Result<Tensor> {
339 let shape = self.shape();
340
341 if dim >= shape.len() {
342 return Err(anyhow!(
343 "Dimension {} is out of bounds for tensor with {} dimensions",
344 dim,
345 shape.len()
346 ));
347 }
348
349 let _abs_values = self.abs()?;
351 let epsilon = 1e-7;
352 let _epsilon_tensor =
353 Tensor::from_data(vec![epsilon], vec![1], self.dtype(), self.layout())?;
354
355 let dim_size = shape[dim];
358 let mut output_shape = shape;
359 output_shape[dim] = 1;
360
361 let count_tensor = Tensor::from_data(
362 vec![dim_size as f32],
363 output_shape,
364 self.dtype(),
365 self.layout(),
366 )?;
367 Ok(count_tensor)
368 }
369
370 pub fn cumsum(&self, dim: usize) -> Result<Tensor> {
372 let shape = self.shape();
373
374 if dim >= shape.len() {
375 return Err(anyhow!(
376 "Dimension {} is out of bounds for tensor with {} dimensions",
377 dim,
378 shape.len()
379 ));
380 }
381
382 let data = self.to_vec()?;
385 let mut cumsum_data = Vec::with_capacity(data.len());
386 let mut running_sum = 0.0;
387
388 for &value in &data {
389 running_sum += value;
390 cumsum_data.push(running_sum);
391 }
392
393 Ok(Tensor::from_data(
394 cumsum_data,
395 shape,
396 self.dtype(),
397 self.layout(),
398 )?)
399 }
400
401 pub fn cumprod(&self, dim: usize) -> Result<Tensor> {
403 let shape = self.shape();
404
405 if dim >= shape.len() {
406 return Err(anyhow!(
407 "Dimension {} is out of bounds for tensor with {} dimensions",
408 dim,
409 shape.len()
410 ));
411 }
412
413 let data = self.to_vec()?;
415 let mut cumprod_data = Vec::with_capacity(data.len());
416 let mut running_prod = 1.0;
417
418 for &value in &data {
419 running_prod *= value;
420 cumprod_data.push(running_prod);
421 }
422
423 Ok(Tensor::from_data(
424 cumprod_data,
425 shape,
426 self.dtype(),
427 self.layout(),
428 )?)
429 }
430
431 pub fn softmax(&self, dim: usize) -> Result<Tensor> {
433 let shape = self.shape();
434
435 if dim >= shape.len() {
436 return Err(anyhow!(
437 "Dimension {} is out of bounds for tensor with {} dimensions",
438 dim,
439 shape.len()
440 ));
441 }
442
443 let max_vals = self.max_dim(dim, true)?;
446 let shifted = self.sub(&max_vals)?;
447 let exp_vals = shifted.exp()?;
448 let sum_exp = exp_vals.sum_dim(dim, true)?;
449 exp_vals.div(&sum_exp)
450 }
451
452 pub fn log_softmax(&self, dim: usize) -> Result<Tensor> {
454 let softmax_result = self.softmax(dim)?;
455 softmax_result.log()
456 }
457}
458
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{DataType, TensorLayout};

    /// Builds a row-major `F32` tensor from flat data and a shape.
    fn f32_tensor(data: Vec<f32>, shape: Vec<usize>) -> Result<Tensor> {
        Ok(Tensor::from_data(
            data,
            shape,
            DataType::F32,
            TensorLayout::RowMajor,
        )?)
    }

    #[test]
    fn test_sum_operations() -> Result<()> {
        let a = f32_tensor(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])?;

        let sum_all_data = a.sum_all()?.to_vec()?;
        assert_eq!(sum_all_data[0], 21.0);

        // Column sums.
        let sum_dim0_data = a.sum_dim(0, false)?.to_vec()?;
        assert_eq!(sum_dim0_data, vec![5.0, 7.0, 9.0]);

        // Row sums.
        let sum_dim1_data = a.sum_dim(1, false)?.to_vec()?;
        assert_eq!(sum_dim1_data, vec![6.0, 15.0]);

        Ok(())
    }

    #[test]
    fn test_mean_operations() -> Result<()> {
        let a = f32_tensor(vec![2.0, 4.0, 6.0, 8.0], vec![2, 2])?;

        let mean_all_data = a.mean_all()?.to_vec()?;
        assert_eq!(mean_all_data[0], 5.0);

        let mean_dim0_data = a.mean_dim(0, false)?.to_vec()?;
        assert_eq!(mean_dim0_data, vec![4.0, 6.0]);

        Ok(())
    }

    #[test]
    fn test_max_min_operations() -> Result<()> {
        let a = f32_tensor(vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0], vec![2, 3])?;

        let max_all_data = a.max_all()?.to_vec()?;
        assert_eq!(max_all_data[0], 9.0);

        let min_all_data = a.min_all()?.to_vec()?;
        assert_eq!(min_all_data[0], 1.0);

        // Per-dimension reductions: row maxima and column minima.
        let max_dim1_data = a.max_dim(1, false)?.to_vec()?;
        assert_eq!(max_dim1_data, vec![4.0, 9.0]);

        let min_dim0_data = a.min_dim(0, false)?.to_vec()?;
        assert_eq!(min_dim0_data, vec![1.0, 1.0, 4.0]);

        Ok(())
    }

    #[test]
    fn test_norm_operations() -> Result<()> {
        let a = f32_tensor(vec![3.0, 4.0], vec![2])?;

        // sqrt(3^2 + 4^2) = 5
        let l2_norm_data = a.norm()?.to_vec()?;
        assert_eq!(l2_norm_data[0], 5.0);

        // |3| + |4| = 7
        let l1_norm_data = a.norm_p(1.0)?.to_vec()?;
        assert_eq!(l1_norm_data[0], 7.0);

        Ok(())
    }

    #[test]
    fn test_variance_std() -> Result<()> {
        let a = f32_tensor(vec![1.0, 2.0, 3.0, 4.0, 5.0], vec![5])?;

        let var_data = a.var_all()?.to_vec()?;
        let std_data = a.std_all()?.to_vec()?;

        // Mean of squared deviations of 1..=5 is 2; its square root is sqrt(2).
        assert!((var_data[0] - 2.0).abs() < 1e-6);
        assert!((std_data[0] - 1.4142135).abs() < 1e-6);

        Ok(())
    }

    #[test]
    fn test_softmax() -> Result<()> {
        let a = f32_tensor(vec![1.0, 2.0, 3.0], vec![3])?;

        let softmax_data = a.softmax(0)?.to_vec()?;

        // Softmax outputs form a probability distribution.
        let sum: f32 = softmax_data.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6);
        assert!(softmax_data.iter().all(|&x| x > 0.0));

        Ok(())
    }

    #[test]
    fn test_argmax_argmin() -> Result<()> {
        let a = f32_tensor(vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0], vec![2, 3])?;

        // Only the element count is asserted: how `to_vec` represents the `U32`
        // index values is not visible from this module.
        let argmax_data = a.argmax(1, false)?.to_vec()?;
        assert_eq!(argmax_data.len(), 2);

        let argmin_data = a.argmin(1, false)?.to_vec()?;
        assert_eq!(argmin_data.len(), 2);

        Ok(())
    }

    #[test]
    fn test_cumulative_operations() -> Result<()> {
        let a = f32_tensor(vec![1.0, 2.0, 3.0, 4.0], vec![4])?;

        let cumsum_data = a.cumsum(0)?.to_vec()?;
        assert_eq!(cumsum_data, vec![1.0, 3.0, 6.0, 10.0]);

        let cumprod_data = a.cumprod(0)?.to_vec()?;
        assert_eq!(cumprod_data, vec![1.0, 2.0, 6.0, 24.0]);

        Ok(())
    }

    #[test]
    fn test_error_handling() {
        let a = f32_tensor(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();

        // Out-of-range dimensions are rejected by every per-dimension operation.
        assert!(a.sum_dim(5, false).is_err());
        assert!(a.mean_dim(5, false).is_err());
        assert!(a.max_dim(5, false).is_err());
        assert!(a.min_dim(5, false).is_err());
        assert!(a.argmax(5, false).is_err());
        assert!(a.argmin(5, false).is_err());
        assert!(a.cumsum(5).is_err());
        assert!(a.cumprod(5).is_err());
        assert!(a.count_nonzero_dim(5).is_err());
        assert!(a.softmax(5).is_err());

        // Norm orders must be strictly positive.
        assert!(a.norm_p(-1.0).is_err());
        assert!(a.norm_p(0.0).is_err());
    }

    #[test]
    fn test_keep_dim() -> Result<()> {
        let a = f32_tensor(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])?;

        let sum_keep = a.sum_dim(1, true)?;
        assert_eq!(sum_keep.shape(), vec![2, 1]);

        let sum_no_keep = a.sum_dim(1, false)?;
        assert_eq!(sum_no_keep.shape(), vec![2]);

        Ok(())
    }
}