use scirs2_core::ndarray::{Array2, ArrayView2};
use scirs2_core::numeric::{Float, NumAssign, One};
use std::iter::Sum;
use crate::error::{LinalgError, LinalgResult};
use crate::validation::validate_decomposition;
/// Numerically stabilised softmax of a 2-D matrix.
///
/// Every lane is shifted by its maximum before exponentiation, so the
/// largest exponent evaluated is `exp(0)`.
///
/// # Arguments
///
/// * `a` - Input matrix view.
/// * `axis` - `None` normalises over the whole matrix, `Some(0)` normalises
///   each column independently, `Some(1)` normalises each row independently.
///
/// # Errors
///
/// * `LinalgError::ShapeError` when either dimension of `a` is zero.
/// * `LinalgError::InvalidInputError` when `axis` is `Some(k)` with `k > 1`.
#[allow(dead_code)]
pub fn softmax<F>(a: &ArrayView2<F>, axis: Option<usize>) -> LinalgResult<Array2<F>>
where
    F: Float + NumAssign + Sum + One + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
{
    let (nrows, ncols) = a.dim();
    if nrows == 0 || ncols == 0 {
        return Err(LinalgError::ShapeError(
            "Matrix must be non-empty for softmax computation".to_string(),
        ));
    }

    // Running maximum using the same strict `>` comparison everywhere.
    let larger = |best: F, candidate: &F| if *candidate > best { *candidate } else { best };

    let mut result = Array2::<F>::zeros((nrows, ncols));
    match axis {
        None => {
            // One lane: the entire matrix, visited in row-major logical order.
            let peak = a.iter().fold(a[[0, 0]], larger);
            let mut total = F::zero();
            for (dst, &src) in result.iter_mut().zip(a.iter()) {
                *dst = (src - peak).exp();
                total += *dst;
            }
            result.iter_mut().for_each(|dst| *dst /= total);
        }
        Some(0) => {
            // One lane per column.
            for col in 0..ncols {
                let input = a.column(col);
                let mut output = result.column_mut(col);
                let peak = input.iter().skip(1).fold(input[0], larger);
                let mut total = F::zero();
                for (dst, &src) in output.iter_mut().zip(input.iter()) {
                    *dst = (src - peak).exp();
                    total += *dst;
                }
                output.iter_mut().for_each(|dst| *dst /= total);
            }
        }
        Some(1) => {
            // One lane per row.
            for row in 0..nrows {
                let input = a.row(row);
                let mut output = result.row_mut(row);
                let peak = input.iter().skip(1).fold(input[0], larger);
                let mut total = F::zero();
                for (dst, &src) in output.iter_mut().zip(input.iter()) {
                    *dst = (src - peak).exp();
                    total += *dst;
                }
                output.iter_mut().for_each(|dst| *dst /= total);
            }
        }
        Some(axis_val) => {
            return Err(LinalgError::InvalidInputError(format!(
                "Invalid axis {} for 2D matrix. Must be 0, 1, or None",
                axis_val
            )));
        }
    }
    Ok(result)
}
/// Element-wise logistic sigmoid, `1 / (1 + exp(-x))`, of a 2-D matrix.
///
/// The formula is evaluated in one of two algebraically equivalent forms,
/// chosen by the sign of `x`, so that `exp` is only ever called on a
/// non-positive argument.
///
/// # Arguments
///
/// * `a` - Input matrix view.
///
/// # Returns
///
/// A new matrix with the same shape as `a`. This function itself never
/// produces an `Err`.
#[allow(dead_code)]
pub fn sigmoid<F>(a: &ArrayView2<F>) -> LinalgResult<Array2<F>>
where
    F: Float + NumAssign + Sum + One + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
{
    let unit = F::one();
    let activated = Array2::<F>::from_shape_fn(a.dim(), |position| {
        let value = a[position];
        if value > F::zero() {
            // Positive input: exp(-x) lies in (0, 1).
            let decay = (-value).exp();
            unit / (unit + decay)
        } else {
            // Non-positive input: exp(x) lies in (0, 1].
            let growth = value.exp();
            growth / (unit + growth)
        }
    });
    Ok(activated)
}
/// Matrix sign function for the special cases implemented here.
///
/// The input is first passed to `validate_decomposition`; any error it
/// reports is propagated unchanged (presumably it enforces a square
/// matrix -- confirm against that helper). After validation:
///
/// * If every entry of the leading `n x n` block has magnitude at most
///   `F::epsilon()`, the `n x n` zero matrix is returned.
/// * If every off-diagonal entry has magnitude at most `F::epsilon()`,
///   a diagonal matrix is returned whose entries are `1`, `-1` or `0`
///   according to the sign of the corresponding diagonal entry of `a`.
///
/// # Arguments
///
/// * `a` - Input matrix view; `n` is taken from its number of rows.
///
/// # Errors
///
/// * Whatever `validate_decomposition` returns for an invalid input.
/// * `LinalgError::ImplementationError` for any matrix that is neither
///   (near-)zero nor (near-)diagonal, since the general case is not
///   implemented.
#[allow(dead_code)]
pub fn signm<F>(a: &ArrayView2<F>) -> LinalgResult<Array2<F>>
where
    F: Float + NumAssign + Sum + One + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
{
    validate_decomposition(a, "Matrix sign computation", true)?;

    let n = a.nrows();
    let tol = F::epsilon();

    // Entries at or below machine epsilon in magnitude are treated as zero.
    let has_nonzero = (0..n).any(|i| (0..n).any(|j| a[[i, j]].abs() > tol));
    if !has_nonzero {
        return Ok(Array2::<F>::zeros((n, n)));
    }

    let has_off_diagonal = (0..n).any(|i| (0..n).any(|j| i != j && a[[i, j]].abs() > tol));
    if !has_off_diagonal {
        // For a diagonal matrix the sign function acts entry-wise on the diagonal.
        let mut result = Array2::<F>::zeros((n, n));
        for i in 0..n {
            let val = a[[i, i]];
            if val > F::zero() {
                result[[i, i]] = F::one();
            } else if val < F::zero() {
                result[[i, i]] = -F::one();
            }
            // Otherwise the entry keeps the zero it was initialised with.
        }
        return Ok(result);
    }

    Err(LinalgError::ImplementationError(
        "Matrix sign function for general matrices is not yet fully implemented".to_string(),
    ))
}