use super::super::core::Tensor;
use crate::error::{RusTorchError, RusTorchResult};
use num_traits::Float;
impl<T: Float + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive> Tensor<T> {
    /// Returns the single element of a one-element tensor.
    ///
    /// # Panics
    ///
    /// Panics if the tensor does not contain exactly one element.
    pub fn item(&self) -> T {
        if self.data.len() != 1 {
            panic!("item() can only be called on tensors with exactly one element");
        }
        // Read through the iterator instead of `self.data[0]`: scalar indexing presumably
        // only works for 1-D storage and would reject shapes such as `[1, 1]`.
        *self
            .data
            .iter()
            .next()
            .expect("len() == 1 guarantees exactly one element")
    }

    /// Returns the mean along `axis`: the result of `sum_axis(axis)` divided by the
    /// length of that axis.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `sum_axis` (presumably an out-of-range `axis`;
    /// confirm against `sum_axis`).
    pub fn mean_axis(&self, axis: usize) -> RusTorchResult<Self> {
        let sum_result = self.sum_axis(axis)?;
        // `sum_axis` succeeded, so `axis` is assumed to index into `shape()`. For f32/f64
        // the usize -> T conversion always succeeds, so the fallback is not expected to run.
        let axis_size = T::from(self.shape()[axis]).unwrap_or(T::one());
        Ok(sum_result.div_scalar(axis_size))
    }

    /// Returns the population variance of all elements: the sum of squared deviations
    /// from `self.mean()` divided by the element count `N` (not `N - 1`).
    pub fn var(&self) -> T {
        let mean = self.mean();
        let squared_diffs: T = self
            .data
            .iter()
            .map(|&x| {
                let diff = x - mean;
                diff * diff
            })
            .fold(T::zero(), |acc, x| acc + x);
        let count = T::from(self.data.len()).unwrap_or(T::one());
        squared_diffs / count
    }

    /// Returns the population standard deviation, the square root of `var()`.
    pub fn std(&self) -> T {
        self.var().sqrt()
    }

    /// Returns the smallest element, or zero for an empty tensor.
    ///
    /// NaN propagates: if any element is NaN, the result is NaN.
    pub fn min(&self) -> T {
        first_extremum(self.data.iter().copied(), std::cmp::Ordering::Less)
            .map_or(T::zero(), |(_, value)| value)
    }

    /// Returns the largest element, or zero for an empty tensor.
    ///
    /// NaN propagates: if any element is NaN, the result is NaN.
    pub fn max(&self) -> T {
        first_extremum(self.data.iter().copied(), std::cmp::Ordering::Greater)
            .map_or(T::zero(), |(_, value)| value)
    }

    /// Returns the position (in `self.data`'s iteration order) of the first smallest
    /// element, or 0 for an empty tensor.
    ///
    /// If any element is NaN, the position of the first NaN is returned, consistent
    /// with `min()`. Never panics.
    pub fn argmin(&self) -> usize {
        first_extremum(self.data.iter().copied(), std::cmp::Ordering::Less)
            .map_or(0, |(idx, _)| idx)
    }

    /// Returns the position (in `self.data`'s iteration order) of the first largest
    /// element, or 0 for an empty tensor.
    ///
    /// Ties resolve to the earliest position, the same rule `argmin()` uses. If any
    /// element is NaN, the position of the first NaN is returned, consistent with
    /// `max()`. Never panics.
    pub fn argmax(&self) -> usize {
        first_extremum(self.data.iter().copied(), std::cmp::Ordering::Greater)
            .map_or(0, |(idx, _)| idx)
    }

    /// Returns the median of all elements.
    ///
    /// For an even element count this is the mean of the two middle values.
    /// Returns NaN if the tensor is empty or contains a NaN.
    pub fn median(&self) -> T {
        let sorted = match sorted_non_nan(self.data.iter().copied()) {
            Some(values) if !values.is_empty() => values,
            _ => return T::nan(),
        };
        let mid = sorted.len() / 2;
        if sorted.len() % 2 == 1 {
            sorted[mid]
        } else {
            (sorted[mid - 1] + sorted[mid]) / (T::one() + T::one())
        }
    }

    /// Returns the `q`-th quantile of all elements.
    ///
    /// Uses linear interpolation between the two nearest ranks, the default `linear`
    /// method of NumPy/PyTorch. As a result, `quantile(0.5)` agrees with `median()` up to
    /// floating-point rounding. Returns NaN if the tensor is empty or contains a NaN.
    ///
    /// # Panics
    ///
    /// Panics if `q` is outside `[0.0, 1.0]` or is NaN.
    pub fn quantile(&self, q: f64) -> T {
        // `contains` is false for NaN, so a NaN `q` is rejected as well.
        if !(0.0..=1.0).contains(&q) {
            panic!("Quantile must be between 0.0 and 1.0");
        }
        let sorted = match sorted_non_nan(self.data.iter().copied()) {
            Some(values) if !values.is_empty() => values,
            // Empty input or NaN data: there is no meaningful quantile.
            _ => return T::nan(),
        };
        let last = sorted.len() - 1;
        // Fractional rank in [0, last]; blend the element at `lo` with its successor.
        let rank = q * last as f64;
        let lo = (rank.floor() as usize).min(last);
        let hi = (lo + 1).min(last);
        let frac = rank - lo as f64;
        let (below, above) = (sorted[lo], sorted[hi]);
        if lo == hi || frac == 0.0 {
            // Exact rank: return the element itself (also avoids `inf - inf`).
            return below;
        }
        let weight = T::from(frac).expect("converting a value in (0, 1) to a Float cannot fail");
        below + (above - below) * weight
    }
}

/// Finds the first extreme element of `values` and returns its position and value.
///
/// A candidate replaces the current best only when it compares strictly `better`
/// (`Less` for a minimum, `Greater` for a maximum), so ties keep the earliest element.
/// The first NaN encountered is returned immediately, so NaN propagates.
/// Returns `None` for an empty input.
fn first_extremum<T: Float>(
    values: impl IntoIterator<Item = T>,
    better: std::cmp::Ordering,
) -> Option<(usize, T)> {
    let mut best: Option<(usize, T)> = None;
    for (idx, value) in values.into_iter().enumerate() {
        if value.is_nan() {
            return Some((idx, value));
        }
        let replaces = match best {
            None => true,
            Some((_, current)) => value.partial_cmp(&current) == Some(better),
        };
        if replaces {
            best = Some((idx, value));
        }
    }
    best
}

/// Returns `values` sorted ascending, or `None` if any of them is NaN.
///
/// Rejecting NaN up front makes `partial_cmp` a total order for the sort. The standard
/// sorts may panic when given a comparator that is not a total order.
fn sorted_non_nan<T: Float>(values: impl IntoIterator<Item = T>) -> Option<Vec<T>> {
    let mut sorted: Vec<T> = values.into_iter().collect();
    if sorted.iter().any(|x| x.is_nan()) {
        return None;
    }
    // No NaN remains, so the `Equal` fallback is never taken.
    sorted.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
    Some(sorted)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sum() {
        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        assert_eq!(tensor.sum(), 10.0);
    }

    #[test]
    fn test_mean() {
        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        assert_eq!(tensor.mean(), 2.5);
    }

    #[test]
    fn test_sum_axis() {
        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let result = tensor.sum_axis(0).unwrap();
        assert_eq!(result.as_slice().unwrap(), &[5.0, 7.0, 9.0]);
        let result = tensor.sum_axis(1).unwrap();
        assert_eq!(result.as_slice().unwrap(), &[6.0, 15.0]);
    }

    #[test]
    fn test_mean_axis() {
        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let result = tensor.mean_axis(0).unwrap();
        assert_eq!(result.as_slice().unwrap(), &[2.5, 3.5, 4.5]);
        let result = tensor.mean_axis(1).unwrap();
        assert_eq!(result.as_slice().unwrap(), &[2.0, 5.0]);
    }

    #[test]
    fn test_item() {
        let tensor = Tensor::from_vec(vec![7.0], vec![1]);
        assert_eq!(tensor.item(), 7.0);
    }

    #[test]
    #[should_panic(expected = "exactly one element")]
    fn test_item_panics_on_multiple_elements() {
        let tensor = Tensor::from_vec(vec![1.0, 2.0], vec![2]);
        let _ = tensor.item();
    }

    #[test]
    fn test_var_std() {
        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![4]);
        let var = tensor.var();
        let std = tensor.std();
        assert!((var - 1.25).abs() < 0.001);
        assert!((std - 1.118).abs() < 0.01);
    }

    #[test]
    fn test_min_max() {
        let tensor = Tensor::from_vec(vec![3.0, 1.0, 4.0, 2.0], vec![4]);
        assert_eq!(tensor.min(), 1.0);
        assert_eq!(tensor.max(), 4.0);
        assert_eq!(tensor.argmin(), 1);
        assert_eq!(tensor.argmax(), 2);
    }

    #[test]
    fn test_median() {
        let tensor = Tensor::from_vec(vec![3.0, 1.0, 4.0, 2.0], vec![4]);
        assert_eq!(tensor.median(), 2.5);
        let tensor_odd = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![3]);
        assert_eq!(tensor_odd.median(), 2.0);
    }

    #[test]
    fn test_quantile_endpoints_and_middle() {
        let tensor = Tensor::from_vec(vec![3.0, 1.0, 2.0], vec![3]);
        assert_eq!(tensor.quantile(0.0), 1.0);
        assert_eq!(tensor.quantile(0.5), 2.0);
        assert_eq!(tensor.quantile(1.0), 3.0);
    }

    #[test]
    #[should_panic(expected = "Quantile must be between 0.0 and 1.0")]
    fn test_quantile_rejects_out_of_range() {
        let tensor = Tensor::from_vec(vec![1.0, 2.0], vec![2]);
        let _ = tensor.quantile(1.5);
    }
}