use crate::array2::Array2;
use crate::error::{Error, Result};
use crate::numeric::Float;
use crate::view2::{ArrayView2, ArrayViewMut2};
/// Computes the arithmetic mean of every column of `a`.
///
/// The returned vector has `a.cols()` entries. When `a` has no rows, every
/// entry is `T::zero()` rather than the result of a division by zero.
pub fn column_means<T: Float>(a: ArrayView2<'_, T>) -> Vec<T> {
    let (n_rows, n_cols) = (a.rows(), a.cols());
    let mut sums = vec![T::zero(); n_cols];
    if n_rows == 0 {
        return sums;
    }
    // Walk the data row by row so each column accumulates its values in
    // ascending row order.
    for row in 0..n_rows {
        for (col, sum) in sums.iter_mut().enumerate() {
            *sum += a[(row, col)];
        }
    }
    let count = T::from_f64(n_rows as f64);
    sums.iter_mut().for_each(|sum| *sum /= count);
    sums
}
/// Computes the arithmetic mean of every row of `a`.
///
/// The returned vector has `a.rows()` entries. When `a` has no columns, every
/// entry is `T::zero()` rather than the result of a division by zero.
pub fn row_means<T: Float>(a: ArrayView2<'_, T>) -> Vec<T> {
    let width = a.cols();
    if width == 0 {
        return vec![T::zero(); a.rows()];
    }
    let count = T::from_f64(width as f64);
    (0..a.rows())
        .map(|row| {
            // Sum the row left to right, then normalise by the column count.
            let mut total = T::zero();
            for col in 0..width {
                total += a[(row, col)];
            }
            total /= count;
            total
        })
        .collect()
}
pub fn column_variances<T: Float>(a: ArrayView2<'_, T>, means: &[T]) -> Result<Vec<T>> {
if means.len() != a.cols() {
return Err(Error::shape(vec![a.cols()], vec![means.len()]));
}
if a.rows() == 0 {
return Ok(vec![T::zero(); a.cols()]);
}
let denom = T::from_f64(a.rows() as f64);
let mut variances = vec![T::zero(); a.cols()];
for i in 0..a.rows() {
for j in 0..a.cols() {
let centered = a[(i, j)] - means[j];
variances[j] += centered * centered;
}
}
for variance in &mut variances {
*variance /= denom;
}
Ok(variances)
}
pub fn row_variances<T: Float>(a: ArrayView2<'_, T>, means: &[T]) -> Result<Vec<T>> {
if means.len() != a.rows() {
return Err(Error::shape(vec![a.rows()], vec![means.len()]));
}
if a.cols() == 0 {
return Ok(vec![T::zero(); a.rows()]);
}
let denom = T::from_f64(a.cols() as f64);
let mut variances = vec![T::zero(); a.rows()];
for i in 0..a.rows() {
for j in 0..a.cols() {
let centered = a[(i, j)] - means[i];
variances[i] += centered * centered;
}
variances[i] /= denom;
}
Ok(variances)
}
pub fn center_columns_inplace<T: Float>(mut a: ArrayViewMut2<'_, T>, means: &[T]) -> Result<()> {
if means.len() != a.cols() {
return Err(Error::shape(vec![a.cols()], vec![means.len()]));
}
for i in 0..a.rows() {
for j in 0..a.cols() {
a[(i, j)] -= means[j];
}
}
Ok(())
}
pub fn center_rows_inplace<T: Float>(mut a: ArrayViewMut2<'_, T>, means: &[T]) -> Result<()> {
if means.len() != a.rows() {
return Err(Error::shape(vec![a.rows()], vec![means.len()]));
}
for i in 0..a.rows() {
for j in 0..a.cols() {
a[(i, j)] -= means[i];
}
}
Ok(())
}
pub fn scale_columns_inplace<T: Float>(mut a: ArrayViewMut2<'_, T>, scales: &[T]) -> Result<()> {
if scales.len() != a.cols() {
return Err(Error::shape(vec![a.cols()], vec![scales.len()]));
}
for i in 0..a.rows() {
for j in 0..a.cols() {
a[(i, j)] *= scales[j];
}
}
Ok(())
}
pub fn scale_rows_inplace<T: Float>(mut a: ArrayViewMut2<'_, T>, scales: &[T]) -> Result<()> {
if scales.len() != a.rows() {
return Err(Error::shape(vec![a.rows()], vec![scales.len()]));
}
for i in 0..a.rows() {
for j in 0..a.cols() {
a[(i, j)] *= scales[i];
}
}
Ok(())
}
/// Standardizes each column of `a` in place: subtracts the column mean and
/// divides by the column scale.
///
/// The scale of a column is the square root of its variance as computed by
/// `column_variances`. Columns whose variance compares equal to `T::zero()`
/// use a scale of `T::one()`, so they are centered but not divided by zero.
///
/// Returns `(means, scales)`, the per-column values that were applied.
///
/// # Errors
///
/// Propagates any error returned by `column_variances`.
pub fn standardize_columns_inplace<T: Float>(
    mut a: ArrayViewMut2<'_, T>,
) -> Result<(Vec<T>, Vec<T>)> {
    let means = column_means(a.as_view());
    // Start from the variances and turn each one into its scale in place.
    let mut scales = column_variances(a.as_view(), &means)?;
    for scale in scales.iter_mut() {
        let variance = *scale;
        *scale = if variance == T::zero() {
            T::one()
        } else {
            variance.sqrt()
        };
    }
    for row in 0..a.rows() {
        // `means` and `scales` both hold one entry per column.
        for (col, (&mean, &scale)) in means.iter().zip(scales.iter()).enumerate() {
            let centered = a[(row, col)] - mean;
            a[(row, col)] = centered / scale;
        }
    }
    Ok((means, scales))
}
/// Standardizes each row of `a` in place: subtracts the row mean and divides
/// by the row scale.
///
/// The scale of a row is the square root of its variance as computed by
/// `row_variances`. Rows whose variance compares equal to `T::zero()` use a
/// scale of `T::one()`, so they are centered but not divided by zero.
///
/// Returns `(means, scales)`, the per-row values that were applied.
///
/// # Errors
///
/// Propagates any error returned by `row_variances`.
pub fn standardize_rows_inplace<T: Float>(mut a: ArrayViewMut2<'_, T>) -> Result<(Vec<T>, Vec<T>)> {
    let means = row_means(a.as_view());
    // Start from the variances and turn each one into its scale in place.
    let mut scales = row_variances(a.as_view(), &means)?;
    for scale in scales.iter_mut() {
        let variance = *scale;
        *scale = if variance == T::zero() {
            T::one()
        } else {
            variance.sqrt()
        };
    }
    let n_cols = a.cols();
    // `means` and `scales` both hold one entry per row.
    for (row, (&mean, &scale)) in means.iter().zip(scales.iter()).enumerate() {
        for col in 0..n_cols {
            let centered = a[(row, col)] - mean;
            a[(row, col)] = centered / scale;
        }
    }
    Ok((means, scales))
}
/// Builds a new array from `a` with each column's mean subtracted.
///
/// The result is produced by `Array2::from_fn` over `a.shape()`; element
/// `(i, j)` is `a[(i, j)]` minus the mean of column `j` as returned by
/// `column_means`. The input view is not modified.
pub fn centered_columns<T: Float>(a: ArrayView2<'_, T>) -> Array2<T> {
    let col_means = column_means(a);
    Array2::from_fn(a.shape(), |row, col| {
        let offset = col_means[col];
        a[(row, col)] - offset
    })
}
/// Builds a new array from `a` with each row's mean subtracted.
///
/// The result is produced by `Array2::from_fn` over `a.shape()`; element
/// `(i, j)` is `a[(i, j)]` minus the mean of row `i` as returned by
/// `row_means`. The input view is not modified.
pub fn centered_rows<T: Float>(a: ArrayView2<'_, T>) -> Array2<T> {
    let per_row_means = row_means(a);
    Array2::from_fn(a.shape(), |row, col| {
        let offset = per_row_means[row];
        a[(row, col)] - offset
    })
}