use crate::array::Array;
use crate::error::{NumRs2Error, Result};
use num_traits::{Float, Zero};
use std::ops::{Add, Mul};
use super::{cumprod, cumsum};
pub fn amax<T>(array: &Array<T>, axis: Option<isize>, keepdims: bool) -> Result<Array<T>>
where
T: PartialOrd + Clone + Zero,
{
max(array, axis, keepdims)
}
pub fn amin<T>(array: &Array<T>, axis: Option<isize>, keepdims: bool) -> Result<Array<T>>
where
T: PartialOrd + Clone + Zero,
{
min(array, axis, keepdims)
}
/// Returns the maximum of `array`, either over every element or along one axis.
///
/// * `axis = None`: reduces all elements to a single value. The result has shape
///   `[1]`, or a shape of `ndim` ones when `keepdims` is true.
/// * `axis = Some(ax)`: reduces along dimension `ax` (negative values count from
///   the last dimension). With `keepdims` the reduced dimension is kept with
///   length 1, otherwise it is removed. A result with no dimensions left gets
///   shape `[1]`.
///
/// Elements are compared with `>`. A later element replaces the current maximum
/// only when it is strictly greater. Values that never compare greater (such as
/// NaN for floats) are therefore returned only if they are encountered first.
///
/// # Errors
///
/// * `NumRs2Error::InvalidOperation` if `array` is empty.
/// * `NumRs2Error::DimensionMismatch` if `axis` is outside `-ndim..ndim`.
/// * Any error returned by `Array::get` while reading elements along `axis`.
pub fn max<T>(array: &Array<T>, axis: Option<isize>, keepdims: bool) -> Result<Array<T>>
where
    T: PartialOrd + Clone + Zero,
{
    if array.is_empty() {
        return Err(NumRs2Error::InvalidOperation(
            "Cannot find max of empty array".to_string(),
        ));
    }
    let ndim = array.ndim();
    match axis {
        None => {
            let data = array.to_vec();
            // Non-emptiness was checked above, so `data[0]` exists.
            let max_val = data.iter().skip(1).fold(data[0].clone(), |best, x| {
                if x > &best {
                    x.clone()
                } else {
                    best
                }
            });
            if keepdims {
                let shape = vec![1; ndim];
                Ok(Array::from_vec(vec![max_val]).reshape(&shape))
            } else {
                Ok(Array::from_vec(vec![max_val]))
            }
        }
        Some(ax) => {
            // Normalize and bounds-check the axis before casting to `usize`, so an
            // out-of-range negative axis is reported as written rather than as a
            // wrapped-around huge number.
            let normalized = if ax < 0 { ax + ndim as isize } else { ax };
            if normalized < 0 || normalized as usize >= ndim {
                return Err(NumRs2Error::DimensionMismatch(format!(
                    "Axis {} out of bounds for array of dimension {}",
                    ax, ndim
                )));
            }
            let axis = normalized as usize;

            let shape = array.shape();
            let axis_size = shape[axis];
            let mut out_shape = shape.clone();
            if keepdims {
                out_shape[axis] = 1;
            } else {
                out_shape.remove(axis);
            }
            if out_shape.is_empty() {
                out_shape.push(1);
            }
            let out_size: usize = out_shape.iter().product();

            let mut result_data = Vec::with_capacity(out_size);
            let mut indices = vec![0; ndim];
            for out_idx in 0..out_size {
                // Decompose `out_idx` in row-major (C) order over the non-reduced
                // dimensions, with the last dimension varying fastest. This matches
                // the row-major layout assumed by the strides elsewhere in this
                // module and, presumably, by `reshape`. The reduced axis itself is
                // filled in by the inner loop below.
                let mut rem = out_idx;
                for d in (0..ndim).rev() {
                    if d != axis {
                        indices[d] = rem % shape[d];
                        rem /= shape[d];
                    }
                }
                let mut max_val: Option<T> = None;
                for j in 0..axis_size {
                    indices[axis] = j;
                    let val = array.get(&indices)?;
                    if max_val.as_ref().is_none_or(|mv| &val > mv) {
                        max_val = Some(val);
                    }
                }
                // The array is non-empty, so every dimension (including `axis`) has
                // length > 0 and the loop above ran at least once.
                result_data.push(max_val.expect("max_val should be set when axis_size > 0"));
            }
            Ok(Array::from_vec(result_data).reshape(&out_shape))
        }
    }
}
/// Returns the minimum of `array`, either over every element or along one axis.
///
/// * `axis = None`: reduces all elements to a single value. The result has shape
///   `[1]`, or a shape of `ndim` ones when `keepdims` is true.
/// * `axis = Some(ax)`: reduces along dimension `ax` (negative values count from
///   the last dimension). With `keepdims` the reduced dimension is kept with
///   length 1, otherwise it is removed. A result with no dimensions left gets
///   shape `[1]`.
///
/// Elements are compared with `<`. A later element replaces the current minimum
/// only when it is strictly less. Values that never compare less (such as NaN for
/// floats) are therefore returned only if they are encountered first.
///
/// # Errors
///
/// * `NumRs2Error::InvalidOperation` if `array` is empty.
/// * `NumRs2Error::DimensionMismatch` if `axis` is outside `-ndim..ndim`.
/// * Any error returned by `Array::get` while reading elements along `axis`.
pub fn min<T>(array: &Array<T>, axis: Option<isize>, keepdims: bool) -> Result<Array<T>>
where
    T: PartialOrd + Clone + Zero,
{
    if array.is_empty() {
        return Err(NumRs2Error::InvalidOperation(
            "Cannot find min of empty array".to_string(),
        ));
    }
    let ndim = array.ndim();
    match axis {
        None => {
            let data = array.to_vec();
            // Non-emptiness was checked above, so `data[0]` exists.
            let min_val = data.iter().skip(1).fold(data[0].clone(), |best, x| {
                if x < &best {
                    x.clone()
                } else {
                    best
                }
            });
            if keepdims {
                let shape = vec![1; ndim];
                Ok(Array::from_vec(vec![min_val]).reshape(&shape))
            } else {
                Ok(Array::from_vec(vec![min_val]))
            }
        }
        Some(ax) => {
            // Normalize and bounds-check the axis before casting to `usize`, so an
            // out-of-range negative axis is reported as written rather than as a
            // wrapped-around huge number.
            let normalized = if ax < 0 { ax + ndim as isize } else { ax };
            if normalized < 0 || normalized as usize >= ndim {
                return Err(NumRs2Error::DimensionMismatch(format!(
                    "Axis {} out of bounds for array of dimension {}",
                    ax, ndim
                )));
            }
            let axis = normalized as usize;

            let shape = array.shape();
            let axis_size = shape[axis];
            let mut out_shape = shape.clone();
            if keepdims {
                out_shape[axis] = 1;
            } else {
                out_shape.remove(axis);
            }
            if out_shape.is_empty() {
                out_shape.push(1);
            }
            let out_size: usize = out_shape.iter().product();

            let mut result_data = Vec::with_capacity(out_size);
            let mut indices = vec![0; ndim];
            for out_idx in 0..out_size {
                // Decompose `out_idx` in row-major (C) order over the non-reduced
                // dimensions, with the last dimension varying fastest. This matches
                // the row-major layout assumed by the strides elsewhere in this
                // module and, presumably, by `reshape`. The reduced axis itself is
                // filled in by the inner loop below.
                let mut rem = out_idx;
                for d in (0..ndim).rev() {
                    if d != axis {
                        indices[d] = rem % shape[d];
                        rem /= shape[d];
                    }
                }
                let mut min_val: Option<T> = None;
                for j in 0..axis_size {
                    indices[axis] = j;
                    let val = array.get(&indices)?;
                    if min_val.as_ref().is_none_or(|mv| &val < mv) {
                        min_val = Some(val);
                    }
                }
                // The array is non-empty, so every dimension (including `axis`) has
                // length > 0 and the loop above ran at least once.
                result_data.push(min_val.expect("min_val should be set when axis_size > 0"));
            }
            Ok(Array::from_vec(result_data).reshape(&out_shape))
        }
    }
}
/// Sums the elements of `array`, either over every element or along one axis.
///
/// * `axis = None`: adds all elements, starting from `T::zero()`. The result has
///   shape `[1]`, or a shape of `ndim` ones when `keepdims` is true.
/// * `axis = Some(ax)`: sums along dimension `ax` (negative values count from the
///   last dimension). With `keepdims` the reduced dimension is kept with length
///   1, otherwise it is removed. A result with no dimensions left gets shape
///   `[1]`.
///
/// Empty input sums to zero. With an axis, the zero-filled result has the same
/// reduced shape a non-empty input would produce.
///
/// # Errors
///
/// * `NumRs2Error::DimensionMismatch` if `axis` is outside `-ndim..ndim`. This is
///   checked for empty arrays too.
/// * Any error returned by `Array::get` while reading elements along `axis`.
pub fn sum<T>(array: &Array<T>, axis: Option<isize>, keepdims: bool) -> Result<Array<T>>
where
    T: Float + Clone + Add<Output = T> + Zero,
{
    let ndim = array.ndim();
    match axis {
        None => {
            if array.is_empty() {
                return Ok(if keepdims {
                    Array::zeros(&vec![1; ndim])
                } else {
                    Array::zeros(&[1])
                });
            }
            let data = array.to_vec();
            let sum_val = data.iter().fold(T::zero(), |acc, x| acc + *x);
            if keepdims {
                let shape = vec![1; ndim];
                Ok(Array::from_vec(vec![sum_val]).reshape(&shape))
            } else {
                Ok(Array::from_vec(vec![sum_val]))
            }
        }
        Some(ax) => {
            // Normalize and bounds-check the axis before casting to `usize`, so an
            // out-of-range negative axis is reported as written rather than as a
            // wrapped-around huge number.
            let normalized = if ax < 0 { ax + ndim as isize } else { ax };
            if normalized < 0 || normalized as usize >= ndim {
                return Err(NumRs2Error::DimensionMismatch(format!(
                    "Axis {} out of bounds for array of dimension {}",
                    ax, ndim
                )));
            }
            let axis = normalized as usize;

            let shape = array.shape();
            let axis_size = shape[axis];
            let mut out_shape = shape.clone();
            if keepdims {
                out_shape[axis] = 1;
            } else {
                out_shape.remove(axis);
            }
            if out_shape.is_empty() {
                out_shape.push(1);
            }
            // An empty input reduces to zeros of the reduced shape, not always [1].
            if array.is_empty() {
                return Ok(Array::zeros(&out_shape));
            }
            let out_size: usize = out_shape.iter().product();

            let mut result_data = Vec::with_capacity(out_size);
            let mut indices = vec![0; ndim];
            for out_idx in 0..out_size {
                // Decompose `out_idx` in row-major (C) order over the non-reduced
                // dimensions, with the last dimension varying fastest. This matches
                // the row-major layout assumed by the strides elsewhere in this
                // module and, presumably, by `reshape`.
                let mut rem = out_idx;
                for d in (0..ndim).rev() {
                    if d != axis {
                        indices[d] = rem % shape[d];
                        rem /= shape[d];
                    }
                }
                let mut acc = T::zero();
                for j in 0..axis_size {
                    indices[axis] = j;
                    acc = acc + array.get(&indices)?;
                }
                result_data.push(acc);
            }
            Ok(Array::from_vec(result_data).reshape(&out_shape))
        }
    }
}
/// Returns a sorted copy of `array`.
///
/// * `axis = None`: sorts all elements and returns them as a flat 1-D array.
/// * `axis = Some(ax)`: sorts each 1-D lane along `ax` independently and keeps the
///   input shape. Negative `ax` counts from the last dimension.
///
/// `_kind` and `_order` are accepted but ignored. The sort is always the stable
/// `slice::sort_by`.
///
/// Elements that are not comparable with themselves (e.g. floating-point NaN)
/// are placed after all comparable elements. This keeps the comparator a total
/// order, which `sort_by` requires.
///
/// An empty `array` is returned as a clone without validating `axis`.
///
/// # Errors
///
/// * `NumRs2Error::DimensionMismatch` if `axis` is outside `-ndim..ndim`.
/// * Any error returned by `Array::get` while reading elements along `axis`.
pub fn sort<T>(
    array: &Array<T>,
    axis: Option<isize>,
    _kind: Option<&str>,
    _order: Option<&[&str]>,
) -> Result<Array<T>>
where
    T: PartialOrd + Clone + Zero,
{
    // Total order over a `PartialOrd` type. When two values are incomparable, the
    // one that is not comparable with itself (NaN-like) sorts last, and two such
    // values compare equal. Treating every incomparable pair as `Equal` would
    // break transitivity, and `sort_by` may panic on such comparators.
    fn nan_last<U: PartialOrd>(a: &U, b: &U) -> std::cmp::Ordering {
        a.partial_cmp(b).unwrap_or_else(|| {
            let a_incomparable = a.partial_cmp(a).is_none();
            let b_incomparable = b.partial_cmp(b).is_none();
            a_incomparable.cmp(&b_incomparable)
        })
    }

    if array.is_empty() {
        return Ok(array.clone());
    }
    match axis {
        None => {
            let mut data = array.to_vec();
            data.sort_by(|a, b| nan_last(a, b));
            Ok(Array::from_vec(data))
        }
        Some(ax) => {
            let ndim = array.ndim();
            // Normalize and bounds-check the axis before casting to `usize`, so an
            // out-of-range negative axis is reported as written.
            let normalized = if ax < 0 { ax + ndim as isize } else { ax };
            if normalized < 0 || normalized as usize >= ndim {
                return Err(NumRs2Error::DimensionMismatch(format!(
                    "Axis {} out of bounds for array of dimension {}",
                    ax, ndim
                )));
            }
            let axis = normalized as usize;

            let shape = array.shape();
            let axis_size = shape[axis];
            let total_size: usize = shape.iter().product();
            let mut result_data = vec![T::zero(); total_size];
            // Row-major (C-order) strides for writing into the flat result buffer.
            let mut strides = vec![1; ndim];
            for i in (0..ndim - 1).rev() {
                strides[i] = strides[i + 1] * shape[i + 1];
            }

            let n_sorts = total_size / axis_size;
            let mut lane: Vec<T> = Vec::with_capacity(axis_size);
            let mut base_indices = vec![0; ndim];
            for sort_idx in 0..n_sorts {
                // Map `sort_idx` to a unique combination of the non-sorted
                // coordinates. Any bijection works because every lane is visited
                // exactly once and written back through `strides`.
                let mut rem = sort_idx;
                for d in 0..ndim {
                    if d != axis {
                        base_indices[d] = rem % shape[d];
                        rem /= shape[d];
                    }
                }
                lane.clear();
                for j in 0..axis_size {
                    base_indices[axis] = j;
                    lane.push(array.get(&base_indices)?);
                }
                lane.sort_by(|a, b| nan_last(a, b));
                for (k, val) in lane.drain(..).enumerate() {
                    base_indices[axis] = k;
                    let flat_idx: usize = base_indices
                        .iter()
                        .zip(&strides)
                        .map(|(&idx, &stride)| idx * stride)
                        .sum();
                    result_data[flat_idx] = val;
                }
            }
            Ok(Array::from_vec(result_data).reshape(&shape))
        }
    }
}
/// Returns indices that partially sort `array` around position `kth`.
///
/// Partitioning uses `slice::select_nth_unstable_by` on an index vector. After
/// the call, position `kth` (of the flattened array, or of each lane along
/// `axis`) holds the index of the element that would be at that position in
/// sorted order. Indices of elements that are not greater come before it, and
/// indices of elements that are not smaller come after it.
///
/// * `axis = None`: partitions the flattened array and returns a 1-D index array.
/// * `axis = Some(ax)`: partitions each lane along `ax` independently and returns
///   an index array with the input shape. Each index refers to a position within
///   its lane. Negative `ax` counts from the last dimension.
///
/// `_kind` and `_order` are accepted but ignored. Elements that are not
/// comparable with themselves (e.g. NaN) are ordered after all comparable
/// elements, so the comparator is a total order.
///
/// # Errors
///
/// * `NumRs2Error::InvalidOperation` if `kth` is not less than the number of
///   elements being partitioned. This includes any empty input.
/// * `NumRs2Error::DimensionMismatch` if `axis` is outside `-ndim..ndim`.
/// * Any error returned by `Array::get` while reading elements along `axis`.
pub fn argpartition<T>(
    array: &Array<T>,
    kth: usize,
    axis: Option<isize>,
    _kind: Option<&str>,
    _order: Option<&[&str]>,
) -> Result<Array<usize>>
where
    T: PartialOrd + Clone + Zero,
{
    // Total order over a `PartialOrd` type. When two values are incomparable, the
    // one that is not comparable with itself (NaN-like) sorts last, and two such
    // values compare equal. `select_nth_unstable_by` requires a total order.
    fn nan_last<U: PartialOrd>(a: &U, b: &U) -> std::cmp::Ordering {
        a.partial_cmp(b).unwrap_or_else(|| {
            let a_incomparable = a.partial_cmp(a).is_none();
            let b_incomparable = b.partial_cmp(b).is_none();
            a_incomparable.cmp(&b_incomparable)
        })
    }

    let Some(ax) = axis else {
        let data = array.to_vec();
        if kth >= data.len() {
            return Err(NumRs2Error::InvalidOperation(format!(
                "kth ({}) out of bounds for array of size {}",
                kth,
                data.len()
            )));
        }
        let mut indices: Vec<usize> = (0..data.len()).collect();
        indices.select_nth_unstable_by(kth, |&a, &b| nan_last(&data[a], &data[b]));
        return Ok(Array::from_vec(indices));
    };

    let ndim = array.ndim();
    // Normalize and bounds-check the axis before casting to `usize`, so an
    // out-of-range negative axis is reported as written.
    let normalized = if ax < 0 { ax + ndim as isize } else { ax };
    if normalized < 0 || normalized as usize >= ndim {
        return Err(NumRs2Error::DimensionMismatch(format!(
            "Axis {} out of bounds for array of dimension {}",
            ax, ndim
        )));
    }
    let axis = normalized as usize;

    let shape = array.shape();
    let axis_size = shape[axis];
    if kth >= axis_size {
        return Err(NumRs2Error::InvalidOperation(format!(
            "kth ({}) out of bounds for axis {} of size {}",
            kth, axis, axis_size
        )));
    }
    let total_size: usize = shape.iter().product();
    let mut result_data = vec![0_usize; total_size];
    // Row-major (C-order) strides for writing into the flat result buffer.
    let mut strides = vec![1; ndim];
    for i in (0..ndim - 1).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }

    // `kth < axis_size` above guarantees `axis_size > 0`, so the division is safe.
    let n_partitions = total_size / axis_size;
    let mut lane: Vec<T> = Vec::with_capacity(axis_size);
    let mut base_indices = vec![0; ndim];
    for part_idx in 0..n_partitions {
        // Map `part_idx` to a unique combination of the non-partitioned
        // coordinates. Every lane is visited exactly once.
        let mut rem = part_idx;
        for d in 0..ndim {
            if d != axis {
                base_indices[d] = rem % shape[d];
                rem /= shape[d];
            }
        }
        lane.clear();
        for j in 0..axis_size {
            base_indices[axis] = j;
            lane.push(array.get(&base_indices)?);
        }
        let mut order: Vec<usize> = (0..axis_size).collect();
        order.select_nth_unstable_by(kth, |&a, &b| nan_last(&lane[a], &lane[b]));
        for (k, &src) in order.iter().enumerate() {
            base_indices[axis] = k;
            let flat_idx: usize = base_indices
                .iter()
                .zip(&strides)
                .map(|(&idx, &stride)| idx * stride)
                .sum();
            result_data[flat_idx] = src;
        }
    }
    Ok(Array::from_vec(result_data).reshape(&shape))
}
/// Returns a new array with `Float::round` applied to every element of `array`.
///
/// This never fails; the `Result` return type only keeps the signature uniform
/// with the other functions in this module.
pub fn round<T>(array: &Array<T>) -> Result<Array<T>>
where
    T: Float + Clone,
{
    let rounded = array.map(|value| value.round());
    Ok(rounded)
}
/// Cumulative sum of `array` along `axis`. This is an alias for `cumsum`.
///
/// All arguments, including the optional `out` buffer, are forwarded unchanged to
/// `cumsum`. Its result and any error it reports are returned as-is.
pub fn cumulative_sum<T>(
    array: &Array<T>,
    axis: Option<isize>,
    out: Option<&mut Array<T>>,
) -> Result<Array<T>>
where
    T: Float + Clone + Add<Output = T> + Send + Sync + 'static,
{
    cumsum(array, axis, out)
}
/// Cumulative product of `array` along `axis`. This is an alias for `cumprod`.
///
/// All arguments, including the optional `out` buffer, are forwarded unchanged to
/// `cumprod`. Its result and any error it reports are returned as-is.
pub fn cumulative_prod<T>(
    array: &Array<T>,
    axis: Option<isize>,
    out: Option<&mut Array<T>>,
) -> Result<Array<T>>
where
    T: Float + Clone + Mul<Output = T> + Send + Sync,
{
    cumprod(array, axis, out)
}