use crate::tensor::Tensor;
use crate::buffer::Buffer;
use crate::shape::Shape;
use crate::device::Device;
use crate::errors::{EtensorError, EtensorResult};
/// Sums every element of `a` into a single-element tensor of shape `[1]`.
///
/// The input buffer is read as a flat `f32` slice, so every element
/// contributes regardless of `a`'s shape. An empty tensor sums to `0.0`.
///
/// The running total is accumulated in `f64` and rounded to `f32` once at
/// the end. Naive `f32` accumulation loses precision on large inputs: once
/// the total reaches 2^24, adding `1.0` is a no-op. It can also overflow to
/// `inf` midway even when the final sum is representable.
///
/// The result is created on `Device::Cpu`, carries over `a.dtype`, and passes
/// `false` as the last `Tensor::new` argument (presumably `requires_grad` —
/// confirm against `Tensor::new`).
///
/// # Errors
///
/// Propagates any error returned by `as_f32_slice` on `a.data`.
pub fn sum_all(a: &Tensor) -> EtensorResult<Tensor> {
    let slice = a.data.as_f32_slice()?;
    // Accumulate in f64 to limit rounding error; cast back to f32 exactly once.
    let total_sum = slice.iter().map(|&x| f64::from(x)).sum::<f64>() as f32;
    Ok(Tensor::new(
        Buffer::from_f32_vec(vec![total_sum]),
        Shape::new(vec![1]),
        Device::Cpu,
        a.dtype,
        false,
    ))
}
/// Computes the arithmetic mean of every element of `a` as a tensor of shape `[1]`.
///
/// The input buffer is read as a flat `f32` slice, so `a`'s shape is ignored.
/// Both the sum and the division are carried out in `f64`, and the result is
/// rounded to `f32` once. This avoids the precision loss of `f32`
/// accumulation and the inexact `len as f32` conversion above 2^24 elements.
///
/// The result is created on `Device::Cpu`, carries over `a.dtype`, and passes
/// `false` as the last `Tensor::new` argument (presumably `requires_grad` —
/// confirm against `Tensor::new`).
///
/// # Errors
///
/// - Propagates any error returned by `as_f32_slice` on `a.data`.
/// - Returns `EtensorError::InternalError` if `a` has no elements.
pub fn mean_all(a: &Tensor) -> EtensorResult<Tensor> {
    let slice = a.data.as_f32_slice()?;
    if slice.is_empty() {
        return Err(EtensorError::InternalError(
            "Cannot calculate the mean of an empty tensor.".to_string(),
        ));
    }
    // Accumulate in f64 to limit rounding error on large tensors.
    let total_sum: f64 = slice.iter().map(|&x| f64::from(x)).sum();
    // `usize as f64` is exact for any length up to 2^53.
    let mean = (total_sum / slice.len() as f64) as f32;
    Ok(Tensor::new(
        Buffer::from_f32_vec(vec![mean]),
        Shape::new(vec![1]),
        Device::Cpu,
        a.dtype,
        false,
    ))
}
/// Returns the largest element of `a` as a tensor of shape `[1]`.
///
/// The input buffer is read as a flat `f32` slice, so `a`'s shape is ignored.
/// NaN propagates: if any element is NaN, the result is NaN. Plain
/// `f32::max` would silently skip NaNs, so an all-NaN tensor would wrongly
/// report `-inf`.
///
/// The result is created on `Device::Cpu`, carries over `a.dtype`, and passes
/// `false` as the last `Tensor::new` argument (presumably `requires_grad` —
/// confirm against `Tensor::new`).
///
/// # Errors
///
/// - Propagates any error returned by `as_f32_slice` on `a.data`.
/// - Returns `EtensorError::InternalError` if `a` has no elements.
pub fn max_all(a: &Tensor) -> EtensorResult<Tensor> {
    let slice = a.data.as_f32_slice()?;
    if slice.is_empty() {
        return Err(EtensorError::InternalError(
            "Cannot calculate the max of an empty tensor.".to_string(),
        ));
    }
    let max_val = slice.iter().copied().fold(f32::NEG_INFINITY, |acc, x| {
        // `f32::max` ignores NaN operands; propagate NaN explicitly instead.
        if acc.is_nan() || x.is_nan() {
            f32::NAN
        } else {
            acc.max(x)
        }
    });
    Ok(Tensor::new(
        Buffer::from_f32_vec(vec![max_val]),
        Shape::new(vec![1]),
        Device::Cpu,
        a.dtype,
        false,
    ))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dtypes::DType;

    /// Builds a CPU tensor of dtype `F32` (last `Tensor::new` flag `false`)
    /// from raw values and dimensions.
    fn cpu_f32(values: Vec<f32>, shape_dims: Vec<usize>) -> Tensor {
        let buffer = Buffer::from_f32_vec(values);
        let shape = Shape::new(shape_dims);
        Tensor::new(buffer, shape, Device::Cpu, DType::F32, false)
    }

    /// Checks that a reduction result has shape `[1]` and holds exactly `expected`.
    fn assert_scalar(result: &Tensor, expected: f32) {
        let values = result.data.as_f32_slice().unwrap();
        assert_eq!(result.shape.dims, vec![1]);
        assert_eq!(values, &[expected]);
    }

    #[test]
    fn test_cpu_reduce_sum_all() {
        let input = cpu_f32(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        let reduced = sum_all(&input).unwrap();
        assert_scalar(&reduced, 10.0);
    }

    #[test]
    fn test_cpu_reduce_mean_all() {
        let input = cpu_f32(vec![2.0, 4.0, 6.0, 8.0], vec![4]);
        let reduced = mean_all(&input).unwrap();
        assert_scalar(&reduced, 5.0);
    }

    #[test]
    fn test_cpu_reduce_max_all() {
        let input = cpu_f32(vec![-5.0, 12.0, 3.0, 42.0, 0.0, -100.0], vec![2, 3, 1]);
        let reduced = max_all(&input).unwrap();
        assert_scalar(&reduced, 42.0);
    }
}