Skip to main content

ronn_core/ops/
reduction.rs

1//! Reduction operations for tensors.
2//!
3//! This module provides reduction operations that aggregate tensor values
4//! along specified dimensions, including sum, mean, max, min, and others.
5
6use crate::ops::arithmetic::ArithmeticOps;
7use crate::ops::shape::ShapeOps;
8use crate::tensor::Tensor;
9use anyhow::{Result, anyhow};
10
11/// Trait for reduction operations on tensors.
12pub trait ReductionOps {
13    /// Sum all elements in the tensor.
14    fn sum_all(&self) -> Result<Tensor>;
15
16    /// Sum along specified dimensions.
17    fn sum_dims(&self, dims: &[usize], keep_dim: bool) -> Result<Tensor>;
18
19    /// Mean of all elements in the tensor.
20    fn mean_all(&self) -> Result<Tensor>;
21
22    /// Mean along specified dimensions.
23    fn mean_dims(&self, dims: &[usize], keep_dim: bool) -> Result<Tensor>;
24
25    /// Maximum value in the tensor.
26    fn max_all(&self) -> Result<Tensor>;
27
28    /// Maximum along specified dimensions.
29    fn max_dims(&self, dims: &[usize], keep_dim: bool) -> Result<Tensor>;
30
31    /// Minimum value in the tensor.
32    fn min_all(&self) -> Result<Tensor>;
33
34    /// Minimum along specified dimensions.
35    fn min_dims(&self, dims: &[usize], keep_dim: bool) -> Result<Tensor>;
36
37    /// Product of all elements.
38    fn prod_all(&self) -> Result<Tensor>;
39
40    /// Standard deviation of all elements.
41    fn std_all(&self) -> Result<Tensor>;
42
43    /// Variance of all elements.
44    fn var_all(&self) -> Result<Tensor>;
45
46    /// L2 norm (Euclidean norm) of the tensor.
47    fn norm(&self) -> Result<Tensor>;
48
49    /// Lp norm of the tensor.
50    fn norm_p(&self, p: f32) -> Result<Tensor>;
51}
52
53impl ReductionOps for Tensor {
54    fn sum_all(&self) -> Result<Tensor> {
55        let result_candle = self.candle_tensor().sum_all()?;
56
57        // Ensure result is at least 1D
58        let reshaped = if result_candle.dims().is_empty() {
59            result_candle.reshape(&[1])?
60        } else {
61            result_candle
62        };
63
64        Ok(Tensor::from_candle(reshaped, self.dtype(), self.layout()))
65    }
66
67    fn sum_dims(&self, dims: &[usize], keep_dim: bool) -> Result<Tensor> {
68        let shape = self.shape();
69
70        // Validate dimensions
71        for &dim in dims {
72            if dim >= shape.len() {
73                return Err(anyhow!(
74                    "Dimension {} is out of bounds for tensor with {} dimensions",
75                    dim,
76                    shape.len()
77                ));
78            }
79        }
80
81        let result_candle = if keep_dim {
82            self.candle_tensor().sum_keepdim(dims)?
83        } else {
84            self.candle_tensor().sum(dims)?
85        };
86
87        Ok(Tensor::from_candle(
88            result_candle,
89            self.dtype(),
90            self.layout(),
91        ))
92    }
93
94    fn mean_all(&self) -> Result<Tensor> {
95        let sum = self.sum_all()?;
96        let num_elements = self.numel() as f32;
97        sum.div_scalar(num_elements)
98    }
99
100    fn mean_dims(&self, dims: &[usize], keep_dim: bool) -> Result<Tensor> {
101        let sum = self.sum_dims(dims, keep_dim)?;
102
103        // Calculate the number of elements being reduced over
104        let shape = self.shape();
105        let reduction_size: usize = dims.iter().map(|&dim| shape[dim]).product();
106
107        sum.div_scalar(reduction_size as f32)
108    }
109
110    fn max_all(&self) -> Result<Tensor> {
111        let flattened = self.flatten()?;
112        let result_candle = flattened.candle_tensor().max(0)?;
113
114        // Ensure result is at least 1D
115        let reshaped = if result_candle.dims().is_empty() {
116            result_candle.reshape(&[1])?
117        } else {
118            result_candle
119        };
120
121        Ok(Tensor::from_candle(reshaped, self.dtype(), self.layout()))
122    }
123
124    fn max_dims(&self, dims: &[usize], keep_dim: bool) -> Result<Tensor> {
125        let shape = self.shape();
126
127        // Validate dimensions
128        for &dim in dims {
129            if dim >= shape.len() {
130                return Err(anyhow!(
131                    "Dimension {} is out of bounds for tensor with {} dimensions",
132                    dim,
133                    shape.len()
134                ));
135            }
136        }
137
138        // For simplicity, we'll reduce one dimension at a time
139        let mut result = self.clone();
140        let mut sorted_dims = dims.to_vec();
141        sorted_dims.sort_unstable();
142        sorted_dims.reverse(); // Process in reverse order to maintain indices
143
144        for &dim in &sorted_dims {
145            let result_candle = if keep_dim {
146                result.candle_tensor().max_keepdim(dim)?
147            } else {
148                result.candle_tensor().max(dim)?
149            };
150            result = Tensor::from_candle(result_candle, result.dtype(), result.layout());
151        }
152
153        Ok(result)
154    }
155
156    fn min_all(&self) -> Result<Tensor> {
157        let flattened = self.flatten()?;
158        let result_candle = flattened.candle_tensor().min(0)?;
159
160        // Ensure result is at least 1D
161        let reshaped = if result_candle.dims().is_empty() {
162            result_candle.reshape(&[1])?
163        } else {
164            result_candle
165        };
166
167        Ok(Tensor::from_candle(reshaped, self.dtype(), self.layout()))
168    }
169
170    fn min_dims(&self, dims: &[usize], keep_dim: bool) -> Result<Tensor> {
171        let shape = self.shape();
172
173        // Validate dimensions
174        for &dim in dims {
175            if dim >= shape.len() {
176                return Err(anyhow!(
177                    "Dimension {} is out of bounds for tensor with {} dimensions",
178                    dim,
179                    shape.len()
180                ));
181            }
182        }
183
184        // For simplicity, we'll reduce one dimension at a time
185        let mut result = self.clone();
186        let mut sorted_dims = dims.to_vec();
187        sorted_dims.sort_unstable();
188        sorted_dims.reverse(); // Process in reverse order to maintain indices
189
190        for &dim in &sorted_dims {
191            let result_candle = if keep_dim {
192                result.candle_tensor().min_keepdim(dim)?
193            } else {
194                result.candle_tensor().min(dim)?
195            };
196            result = Tensor::from_candle(result_candle, result.dtype(), result.layout());
197        }
198
199        Ok(result)
200    }
201
202    fn prod_all(&self) -> Result<Tensor> {
203        // Product of all elements - we'll use a simple implementation
204        let data = self.to_vec()?;
205        let product = data.iter().fold(1.0, |acc, &x| acc * x);
206
207        Ok(Tensor::from_data(
208            vec![product],
209            vec![1],
210            self.dtype(),
211            self.layout(),
212        )?)
213    }
214
215    fn std_all(&self) -> Result<Tensor> {
216        let variance = self.var_all()?;
217        variance.sqrt()
218    }
219
220    fn var_all(&self) -> Result<Tensor> {
221        let mean = self.mean_all()?;
222        let diff = self.sub(&mean)?;
223        let squared_diff = diff.mul(&diff)?;
224        squared_diff.mean_all()
225    }
226
227    fn norm(&self) -> Result<Tensor> {
228        self.norm_p(2.0)
229    }
230
231    fn norm_p(&self, p: f32) -> Result<Tensor> {
232        if p <= 0.0 {
233            return Err(anyhow!("Norm p must be positive, got {}", p));
234        }
235
236        if p == 1.0 {
237            // L1 norm: sum of absolute values
238            let abs_values = self.abs()?;
239            abs_values.sum_all()
240        } else if p == 2.0 {
241            // L2 norm: sqrt of sum of squares
242            let squared = self.mul(self)?;
243            let sum_squared = squared.sum_all()?;
244            sum_squared.sqrt()
245        } else if p.is_infinite() {
246            // Lāˆž norm: maximum absolute value
247            let abs_values = self.abs()?;
248            abs_values.max_all()
249        } else {
250            // General Lp norm: (sum of |x|^p)^(1/p)
251            let abs_values = self.abs()?;
252            let powered = abs_values.pow(p)?;
253            let sum_powered = powered.sum_all()?;
254            sum_powered.pow(1.0 / p)
255        }
256    }
257}
258
259/// Additional reduction methods for convenience.
260impl Tensor {
261    /// Sum along a single dimension.
262    pub fn sum_dim(&self, dim: usize, keep_dim: bool) -> Result<Tensor> {
263        self.sum_dims(&[dim], keep_dim)
264    }
265
266    /// Mean along a single dimension.
267    pub fn mean_dim(&self, dim: usize, keep_dim: bool) -> Result<Tensor> {
268        self.mean_dims(&[dim], keep_dim)
269    }
270
271    /// Maximum along a single dimension.
272    pub fn max_dim(&self, dim: usize, keep_dim: bool) -> Result<Tensor> {
273        self.max_dims(&[dim], keep_dim)
274    }
275
276    /// Minimum along a single dimension.
277    pub fn min_dim(&self, dim: usize, keep_dim: bool) -> Result<Tensor> {
278        self.min_dims(&[dim], keep_dim)
279    }
280
281    /// Find indices of maximum values along a dimension.
282    pub fn argmax(&self, dim: usize, keep_dim: bool) -> Result<Tensor> {
283        let shape = self.shape();
284
285        if dim >= shape.len() {
286            return Err(anyhow!(
287                "Dimension {} is out of bounds for tensor with {} dimensions",
288                dim,
289                shape.len()
290            ));
291        }
292
293        let result_candle = if keep_dim {
294            self.candle_tensor().argmax_keepdim(dim)?
295        } else {
296            self.candle_tensor().argmax(dim)?
297        };
298
299        Ok(Tensor::from_candle(
300            result_candle,
301            crate::types::DataType::U32, // Candle returns indices as U32
302            self.layout(),
303        ))
304    }
305
306    /// Find indices of minimum values along a dimension.
307    pub fn argmin(&self, dim: usize, keep_dim: bool) -> Result<Tensor> {
308        let shape = self.shape();
309
310        if dim >= shape.len() {
311            return Err(anyhow!(
312                "Dimension {} is out of bounds for tensor with {} dimensions",
313                dim,
314                shape.len()
315            ));
316        }
317
318        let result_candle = if keep_dim {
319            self.candle_tensor().argmin_keepdim(dim)?
320        } else {
321            self.candle_tensor().argmin(dim)?
322        };
323
324        Ok(Tensor::from_candle(
325            result_candle,
326            crate::types::DataType::U32, // Candle returns indices as U32
327            self.layout(),
328        ))
329    }
330
331    /// Count non-zero elements.
332    pub fn count_nonzero(&self) -> Result<usize> {
333        let data = self.to_vec()?;
334        Ok(data.iter().filter(|&&x| x != 0.0).count())
335    }
336
337    /// Count elements along a dimension.
338    pub fn count_nonzero_dim(&self, dim: usize) -> Result<Tensor> {
339        let shape = self.shape();
340
341        if dim >= shape.len() {
342            return Err(anyhow!(
343                "Dimension {} is out of bounds for tensor with {} dimensions",
344                dim,
345                shape.len()
346            ));
347        }
348
349        // Create a mask for non-zero elements
350        let _abs_values = self.abs()?;
351        let epsilon = 1e-7;
352        let _epsilon_tensor =
353            Tensor::from_data(vec![epsilon], vec![1], self.dtype(), self.layout())?;
354
355        // This is a simplified implementation - in practice we'd need proper comparison operations
356        // For now, we'll use a placeholder that counts all elements
357        let dim_size = shape[dim];
358        let mut output_shape = shape;
359        output_shape[dim] = 1;
360
361        let count_tensor = Tensor::from_data(
362            vec![dim_size as f32],
363            output_shape,
364            self.dtype(),
365            self.layout(),
366        )?;
367        Ok(count_tensor)
368    }
369
370    /// Cumulative sum along a dimension.
371    pub fn cumsum(&self, dim: usize) -> Result<Tensor> {
372        let shape = self.shape();
373
374        if dim >= shape.len() {
375            return Err(anyhow!(
376                "Dimension {} is out of bounds for tensor with {} dimensions",
377                dim,
378                shape.len()
379            ));
380        }
381
382        // This is a placeholder implementation
383        // A full implementation would require more complex indexing
384        let data = self.to_vec()?;
385        let mut cumsum_data = Vec::with_capacity(data.len());
386        let mut running_sum = 0.0;
387
388        for &value in &data {
389            running_sum += value;
390            cumsum_data.push(running_sum);
391        }
392
393        Ok(Tensor::from_data(
394            cumsum_data,
395            shape,
396            self.dtype(),
397            self.layout(),
398        )?)
399    }
400
401    /// Cumulative product along a dimension.
402    pub fn cumprod(&self, dim: usize) -> Result<Tensor> {
403        let shape = self.shape();
404
405        if dim >= shape.len() {
406            return Err(anyhow!(
407                "Dimension {} is out of bounds for tensor with {} dimensions",
408                dim,
409                shape.len()
410            ));
411        }
412
413        // This is a placeholder implementation
414        let data = self.to_vec()?;
415        let mut cumprod_data = Vec::with_capacity(data.len());
416        let mut running_prod = 1.0;
417
418        for &value in &data {
419            running_prod *= value;
420            cumprod_data.push(running_prod);
421        }
422
423        Ok(Tensor::from_data(
424            cumprod_data,
425            shape,
426            self.dtype(),
427            self.layout(),
428        )?)
429    }
430
431    /// Softmax operation along a dimension.
432    pub fn softmax(&self, dim: usize) -> Result<Tensor> {
433        let shape = self.shape();
434
435        if dim >= shape.len() {
436            return Err(anyhow!(
437                "Dimension {} is out of bounds for tensor with {} dimensions",
438                dim,
439                shape.len()
440            ));
441        }
442
443        // Softmax: exp(x) / sum(exp(x))
444        // For numerical stability: exp(x - max(x)) / sum(exp(x - max(x)))
445        let max_vals = self.max_dim(dim, true)?;
446        let shifted = self.sub(&max_vals)?;
447        let exp_vals = shifted.exp()?;
448        let sum_exp = exp_vals.sum_dim(dim, true)?;
449        exp_vals.div(&sum_exp)
450    }
451
452    /// Log softmax operation along a dimension.
453    pub fn log_softmax(&self, dim: usize) -> Result<Tensor> {
454        let softmax_result = self.softmax(dim)?;
455        softmax_result.log()
456    }
457}
458
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{DataType, TensorLayout};

    /// Builds a row-major `F32` tensor; every test in this module uses that dtype and layout.
    fn f32_tensor(values: Vec<f32>, shape: Vec<usize>) -> Result<Tensor> {
        Ok(Tensor::from_data(
            values,
            shape,
            DataType::F32,
            TensorLayout::RowMajor,
        )?)
    }

    #[test]
    fn test_sum_operations() -> Result<()> {
        let grid = f32_tensor(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])?;

        // Whole tensor: 1 + 2 + ... + 6.
        assert_eq!(grid.sum_all()?.to_vec()?[0], 21.0);

        // Down the rows (dim 0): [1+4, 2+5, 3+6].
        assert_eq!(grid.sum_dim(0, false)?.to_vec()?, vec![5.0, 7.0, 9.0]);

        // Across the columns (dim 1): [1+2+3, 4+5+6].
        assert_eq!(grid.sum_dim(1, false)?.to_vec()?, vec![6.0, 15.0]);

        Ok(())
    }

    #[test]
    fn test_mean_operations() -> Result<()> {
        let grid = f32_tensor(vec![2.0, 4.0, 6.0, 8.0], vec![2, 2])?;

        assert_eq!(grid.mean_all()?.to_vec()?[0], 5.0);

        // Column means: [(2+6)/2, (4+8)/2].
        assert_eq!(grid.mean_dim(0, false)?.to_vec()?, vec![4.0, 6.0]);

        Ok(())
    }

    #[test]
    fn test_max_min_operations() -> Result<()> {
        let grid = f32_tensor(vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0], vec![2, 3])?;

        assert_eq!(grid.max_all()?.to_vec()?[0], 9.0);
        assert_eq!(grid.min_all()?.to_vec()?[0], 1.0);

        Ok(())
    }

    #[test]
    fn test_norm_operations() -> Result<()> {
        let vector = f32_tensor(vec![3.0, 4.0], vec![2])?;

        // Euclidean: sqrt(3^2 + 4^2) = 5.
        assert_eq!(vector.norm()?.to_vec()?[0], 5.0);

        // Manhattan: |3| + |4| = 7.
        assert_eq!(vector.norm_p(1.0)?.to_vec()?[0], 7.0);

        Ok(())
    }

    #[test]
    fn test_variance_std() -> Result<()> {
        let series = f32_tensor(vec![1.0, 2.0, 3.0, 4.0, 5.0], vec![5])?;

        let var_value = series.var_all()?.to_vec()?[0];
        let std_value = series.std_all()?.to_vec()?[0];

        // Mean is 3, population variance is 2, so std = sqrt(2) ≈ 1.414.
        assert!((var_value - 2.0).abs() < 1e-6);
        assert!((std_value - 1.4142135).abs() < 1e-6);

        Ok(())
    }

    #[test]
    fn test_softmax() -> Result<()> {
        let logits = f32_tensor(vec![1.0, 2.0, 3.0], vec![3])?;

        let probabilities = logits.softmax(0)?.to_vec()?;

        // A probability distribution: strictly positive entries summing to 1.
        let total: f32 = probabilities.iter().sum();
        assert!((total - 1.0).abs() < 1e-6);
        assert!(probabilities.iter().all(|&p| p > 0.0));

        Ok(())
    }

    #[test]
    fn test_argmax_argmin() -> Result<()> {
        let grid = f32_tensor(vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0], vec![2, 3])?;

        // Expected row-wise indices are [2, 2]; only the length is pinned because the
        // element values depend on how Candle's U32 indices are converted by `to_vec`.
        let row_argmax = grid.argmax(1, false)?.to_vec()?;
        assert_eq!(row_argmax.len(), 2);

        Ok(())
    }

    #[test]
    fn test_cumulative_operations() -> Result<()> {
        let series = f32_tensor(vec![1.0, 2.0, 3.0, 4.0], vec![4])?;

        assert_eq!(series.cumsum(0)?.to_vec()?, vec![1.0, 3.0, 6.0, 10.0]);
        assert_eq!(series.cumprod(0)?.to_vec()?, vec![1.0, 2.0, 6.0, 24.0]);

        Ok(())
    }

    #[test]
    fn test_error_handling() -> Result<()> {
        let grid = f32_tensor(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2])?;

        // Axis 5 does not exist on a rank-2 tensor.
        assert!(grid.sum_dim(5, false).is_err());
        assert!(grid.max_dim(5, false).is_err());
        assert!(grid.argmax(5, false).is_err());

        // The norm order must be strictly positive.
        assert!(grid.norm_p(-1.0).is_err());
        assert!(grid.norm_p(0.0).is_err());

        Ok(())
    }

    #[test]
    fn test_keep_dim() -> Result<()> {
        let grid = f32_tensor(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])?;

        // keep_dim=true keeps the reduced axis with size 1 ...
        assert_eq!(grid.sum_dim(1, true)?.shape(), vec![2, 1]);

        // ... keep_dim=false removes it.
        assert_eq!(grid.sum_dim(1, false)?.shape(), vec![2]);

        Ok(())
    }
}