use std::{
borrow::Cow,
ops::{Add, Div, Mul, Sub},
};
use num_traits::{FromPrimitive, One, Zero};
use crate::Array;
impl<'a, T: Clone + Ord, const D: usize> Array<'a, T, D> {
pub fn max(&self) -> Option<T> {
self.flat().max().cloned()
}
pub fn arg_max(&self) -> Vec<usize> {
let mut positions = vec![];
if let Some(max) = self.max() {
for (index, value) in self.flat().enumerate() {
if value == &max {
positions.push(index)
}
}
}
positions
}
pub fn max_across(&self, axis: usize) -> Vec<Option<T>> {
self.axis_view(axis).map(|view| view.max()).collect()
}
pub fn arg_max_across(&self, axis: usize) -> Vec<Option<usize>> {
self.axis_view(axis)
.map(|view| view.arg_max().get(0).copied())
.collect()
}
pub fn min(&self) -> Option<T> {
self.flat().min().cloned()
}
pub fn arg_min(&self) -> Vec<usize> {
let mut positions = vec![];
if let Some(min) = self.min() {
for (index, value) in self.flat().enumerate() {
if value == &min {
positions.push(index)
}
}
}
positions
}
pub fn min_across(&self, axis: usize) -> Vec<Option<T>> {
self.axis_view(axis).map(|view| view.min()).collect()
}
pub fn arg_min_across(&self, axis: usize) -> Vec<Option<usize>> {
self.axis_view(axis)
.map(|view| view.arg_min().get(0).copied())
.collect()
}
pub fn clip(&self, min: &T, max: &T) -> Array<'a, T, D> {
let vec: Vec<T> = self
.vec
.iter()
.map(|val| val.clamp(min, max).clone())
.collect();
let shape = self.shape.clone();
let strides = self.strides.clone();
let idx_maps = self.idx_maps.clone();
Array {
vec: Cow::from(vec),
shape,
strides,
idx_maps,
}
}
}
impl<'a, T, const D: usize> Array<'a, T, D>
where
    T: Clone + Ord + Sub<Output = T>,
{
    /// Peak-to-peak range of the array: the largest element minus the smallest.
    ///
    /// Yields `None` when the array has no elements.
    pub fn ptp(&self) -> Option<T> {
        let highest = self.max()?;
        let lowest = self.min()?;
        Some(highest - lowest)
    }

    /// Peak-to-peak range of every view produced by `axis_view(axis)`, one entry per view.
    pub fn ptp_across(&self, axis: usize) -> Vec<Option<T>> {
        self.axis_view(axis)
            .map(|lane| lane.ptp())
            .collect::<Vec<_>>()
    }
}
impl<'a, T, const D: usize> Array<'a, T, D>
where
    T: Clone + Add<Output = T> + Zero,
{
    /// Adds up every element, starting from `T::zero()` and accumulating in `flat()` order.
    ///
    /// An empty array sums to `T::zero()`.
    pub fn sum(&self) -> T {
        let mut total = T::zero();
        for item in self.flat() {
            total = total + item.clone();
        }
        total
    }

    /// Sum of every view produced by `axis_view(axis)`, one entry per view.
    pub fn sum_across(&self, axis: usize) -> Vec<T> {
        self.axis_view(axis)
            .map(|lane| lane.sum())
            .collect::<Vec<_>>()
    }
}
impl<'a, T, const D: usize> Array<'a, T, D>
where
    T: Clone + Mul<Output = T> + One,
{
    /// Multiplies every element together, starting from `T::one()` and accumulating in
    /// `flat()` order.
    ///
    /// An empty array has product `T::one()`.
    pub fn prod(&self) -> T {
        let mut product = T::one();
        for factor in self.flat() {
            product = product * factor.clone();
        }
        product
    }

    /// Product of every view produced by `axis_view(axis)`, one entry per view.
    pub fn prod_across(&self, axis: usize) -> Vec<T> {
        self.axis_view(axis)
            .map(|lane| lane.prod())
            .collect::<Vec<_>>()
    }
}
impl<'a, T, const D: usize> Array<'a, T, D>
where
    T: Clone + Add<Output = T> + FromPrimitive + Div<T, Output = T> + Zero,
{
    /// Arithmetic mean of all elements: `sum()` divided by the element count, where the
    /// count is the product of the shape's extents.
    ///
    /// For integer `T` the division truncates (e.g. the mean of `1, 2, 3, 4` is `2`).
    ///
    /// # Panics
    ///
    /// * If the element count cannot be represented as `T`
    ///   (`FromPrimitive::from_usize` returns `None`).
    /// * For integer `T`, if the array is empty (integer division by zero).
    pub fn mean(&self) -> T {
        let total = self.sum();
        let count: usize = self.shape().iter().product();
        let divisor = T::from_usize(count)
            .expect("Array::mean: element count is not representable in the element type");
        total / divisor
    }

    /// Computes [`mean`](Self::mean) for each view produced by `axis_view(axis)`,
    /// producing one entry per view.
    ///
    /// # Panics
    ///
    /// Under the same conditions as `mean`, for any individual view.
    pub fn mean_across(&self, axis: usize) -> Vec<T> {
        self.axis_view(axis).map(|view| view.mean()).collect()
    }
}
impl<'a, T, const D: usize> Array<'a, T, D>
where
    T: Clone + Sub<Output = T> + FromPrimitive + Div<T, Output = T> + Mul<Output = T> + Zero,
{
    /// Population variance of all elements: the sum of squared deviations from `mean()`,
    /// divided by the element count `n` (not `n - 1`).
    ///
    /// The element count is the product of the shape's extents.
    ///
    /// # Panics
    ///
    /// * If the element count cannot be represented as `T`
    ///   (`FromPrimitive::from_usize` returns `None`).
    /// * For integer `T`, if the array is empty (integer division by zero, raised first by
    ///   the call to `mean()`).
    pub fn var(&self) -> T {
        let mean = self.mean();
        let sum_of_squares = self.flat().fold(T::zero(), |acc, val| {
            // Compute the deviation once and square it, instead of subtracting twice.
            let diff = val.clone() - mean.clone();
            acc + diff.clone() * diff
        });
        let count: usize = self.shape().iter().product();
        let divisor = T::from_usize(count)
            .expect("Array::var: element count is not representable in the element type");
        sum_of_squares / divisor
    }

    /// Computes [`var`](Self::var) for each view produced by `axis_view(axis)`,
    /// producing one entry per view.
    ///
    /// # Panics
    ///
    /// Under the same conditions as `var`, for any individual view.
    pub fn var_across(&self, axis: usize) -> Vec<T> {
        self.axis_view(axis).map(|view| view.var()).collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn max() {
        let array = Array::init(vec![0, 1, 2, 3], [2, 2]);
        assert_eq!(array.max().unwrap(), 3);
    }

    #[test]
    fn arg_max() {
        let array = Array::init(vec![0, 1, 2, 3], [2, 2]);
        assert_eq!(array.arg_max()[0], 3);
    }

    #[test]
    fn arg_max_reports_every_tie() {
        let array = Array::init(vec![3, 1, 3, 0], [2, 2]);
        assert_eq!(array.arg_max(), vec![0, 2]);
    }

    #[test]
    fn max_across() {
        let array = Array::init(vec![0, 1, 2, 3], [2, 2]);
        assert_eq!(array.max_across(1), vec![Some(2), Some(3)]);
        assert_eq!(array.max_across(0), vec![Some(1), Some(3)]);
    }

    #[test]
    fn arg_max_across() {
        let array = Array::init(vec![0, 1, 2, 3], [2, 2]);
        assert_eq!(array.arg_max_across(1), vec![Some(1), Some(1)]);
        assert_eq!(array.arg_max_across(0), vec![Some(1), Some(1)]);
    }

    #[test]
    fn min() {
        let array = Array::init(vec![0, 1, 2, 3], [2, 2]);
        assert_eq!(array.min().unwrap(), 0);
    }

    #[test]
    fn arg_min() {
        let array = Array::init(vec![0, 1, 2, 3], [2, 2]);
        assert_eq!(array.arg_min()[0], 0);
    }

    #[test]
    fn arg_min_reports_every_tie() {
        let array = Array::init(vec![0, 1, 0, 2], [2, 2]);
        assert_eq!(array.arg_min(), vec![0, 2]);
    }

    #[test]
    fn min_across() {
        let array = Array::init(vec![0, 1, 2, 3], [2, 2]);
        assert_eq!(array.min_across(1), vec![Some(0), Some(1)]);
        assert_eq!(array.min_across(0), vec![Some(0), Some(2)]);
    }

    #[test]
    fn arg_min_across() {
        let array = Array::init(vec![0, 1, 2, 3], [2, 2]);
        assert_eq!(array.arg_min_across(1), vec![Some(0), Some(0)]);
        assert_eq!(array.arg_min_across(0), vec![Some(0), Some(0)]);
    }

    #[test]
    fn clip() {
        let array = Array::arange(0..10);
        let clipped = array.clip(&1, &8);
        assert_eq!(
            clipped.flat().copied().collect::<Vec<i32>>(),
            vec![1, 1, 2, 3, 4, 5, 6, 7, 8, 8]
        );
    }

    #[test]
    #[should_panic]
    fn clip_panics_when_bounds_are_inverted() {
        // `Ord::clamp` asserts `min <= max`; `clip` documents that it propagates the panic.
        let array = Array::arange(0..10);
        let _ = array.clip(&8, &1);
    }

    #[test]
    fn ptp() {
        let array = Array::init(vec![4, 9, 2, 10, 6, 9, 7, 12], [2, 4]);
        assert_eq!(array.ptp().unwrap(), 10)
    }

    #[test]
    fn ptp_across() {
        let array = Array::init(vec![4, 9, 2, 10, 6, 9, 7, 12], [2, 4]);
        assert_eq!(array.ptp_across(0), vec![Some(8), Some(6)]);
        assert_eq!(
            array.ptp_across(1),
            vec![Some(2), Some(0), Some(5), Some(2)]
        )
    }

    #[test]
    fn sum() {
        let array = Array::arange(1..5).reshape([2, 2]);
        assert_eq!(array.sum(), 10);
    }

    #[test]
    fn sum_across() {
        let array = Array::arange(1..5).reshape([2, 2]);
        assert_eq!(array.sum_across(0), vec![3, 7]);
        assert_eq!(array.sum_across(1), vec![4, 6]);
    }

    #[test]
    fn prod() {
        let array = Array::arange(1..5).reshape([2, 2]);
        assert_eq!(array.prod(), 24);
    }

    #[test]
    fn prod_across() {
        let array = Array::arange(1..5).reshape([2, 2]);
        assert_eq!(array.prod_across(0), vec![2, 12]);
        assert_eq!(array.prod_across(1), vec![3, 8]);
    }

    #[test]
    fn mean() {
        let array = Array::arange(1..5).reshape([2, 2]);
        assert_eq!(array.mean(), 2);
    }

    #[test]
    fn mean_across() {
        let array = Array::arange(1..5).reshape([2, 2]);
        assert_eq!(array.mean_across(0), vec![1, 3]);
        assert_eq!(array.mean_across(1), vec![2, 3]);
    }

    #[test]
    fn var() {
        let array = Array::init(vec![1.0, 2.0, 3.0, 4.0], [2, 2]);
        assert_eq!(array.var(), 1.25);
    }

    #[test]
    fn var_across() {
        let array = Array::init(vec![1.0, 2.0, 3.0, 4.0], [2, 2]);
        assert_eq!(array.var_across(0), vec![0.25, 0.25]);
        assert_eq!(array.var_across(1), vec![1.0, 1.0]);
    }
}