use scirs2_core::ndarray::{Array2, ArrayView2};
use scirs2_core::numeric::{Float, NumAssign, One};
use std::iter::Sum;
use crate::error::{LinalgError, LinalgResult};
use crate::validation::validate_decomposition;
/// Computes the softmax of a 2-D matrix, either over every element or along one axis.
///
/// Each group of values is shifted by its maximum before exponentiation so that
/// the largest exponent evaluated is `exp(0)`; the exponentials are then divided
/// by their sum.
///
/// # Arguments
///
/// * `a` - Input matrix.
/// * `axis` - `None` normalises over all elements at once, `Some(0)` normalises
///   each column independently, and `Some(1)` normalises each row independently.
///
/// # Errors
///
/// * `LinalgError::ShapeError` when the matrix has zero rows or zero columns.
/// * `LinalgError::InvalidInputError` when `axis` is `Some(k)` with `k` other
///   than 0 or 1.
#[allow(dead_code)]
pub fn softmax<F>(a: &ArrayView2<F>, axis: Option<usize>) -> LinalgResult<Array2<F>>
where
    F: Float + NumAssign + Sum + One + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
{
    let (nrows, ncols) = a.dim();
    if nrows == 0 || ncols == 0 {
        return Err(LinalgError::ShapeError(
            "Matrix must be non-empty for softmax computation".to_string(),
        ));
    }

    // Running-maximum step: the candidate replaces the current best only when it
    // compares strictly greater, so a NaN candidate never displaces the best.
    let keep_larger = |best: F, candidate: &F| if *candidate > best { *candidate } else { best };

    let mut result = Array2::<F>::zeros((nrows, ncols));
    match axis {
        None => {
            // `iter()` walks the matrix in logical (row-major) order, which fixes
            // the order in which the exponentials are accumulated.
            let peak = a.iter().fold(a[[0, 0]], keep_larger);
            let mut total = F::zero();
            for (out, &value) in result.iter_mut().zip(a.iter()) {
                let weight = (value - peak).exp();
                *out = weight;
                total += weight;
            }
            for out in result.iter_mut() {
                *out /= total;
            }
        }
        Some(0) => {
            // Normalise down each column.
            for j in 0..ncols {
                let source = a.column(j);
                let mut target = result.column_mut(j);
                let peak = source.iter().fold(source[0], keep_larger);
                let mut total = F::zero();
                for (out, &value) in target.iter_mut().zip(source.iter()) {
                    let weight = (value - peak).exp();
                    *out = weight;
                    total += weight;
                }
                for out in target.iter_mut() {
                    *out /= total;
                }
            }
        }
        Some(1) => {
            // Normalise across each row.
            for i in 0..nrows {
                let source = a.row(i);
                let mut target = result.row_mut(i);
                let peak = source.iter().fold(source[0], keep_larger);
                let mut total = F::zero();
                for (out, &value) in target.iter_mut().zip(source.iter()) {
                    let weight = (value - peak).exp();
                    *out = weight;
                    total += weight;
                }
                for out in target.iter_mut() {
                    *out /= total;
                }
            }
        }
        Some(axis_val) => {
            return Err(LinalgError::InvalidInputError(format!(
                "Invalid axis {} for 2D matrix. Must be 0, 1, or None",
                axis_val
            )));
        }
    }
    Ok(result)
}
/// Applies the logistic sigmoid `1 / (1 + exp(-x))` to every element of a matrix.
///
/// The formula is evaluated in two algebraically equivalent forms chosen by the
/// sign of the element, so `exp` is only ever called on a non-positive argument
/// and cannot overflow for inputs of large magnitude.
///
/// # Arguments
///
/// * `a` - Input matrix.
///
/// # Returns
///
/// A newly allocated matrix with the same shape as `a`. This function has no
/// failure path of its own and always returns `Ok`.
#[allow(dead_code)]
pub fn sigmoid<F>(a: &ArrayView2<F>) -> LinalgResult<Array2<F>>
where
    F: Float + NumAssign + Sum + One + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
{
    let activated = Array2::<F>::from_shape_fn(a.dim(), |(i, j)| {
        let value = a[[i, j]];
        if value > F::zero() {
            // Positive input: exp(-value) lies in (0, 1].
            let decay = (-value).exp();
            F::one() / (F::one() + decay)
        } else {
            // Zero, negative (or NaN) input: exp(value) lies in (0, 1].
            let growth = value.exp();
            growth / (F::one() + growth)
        }
    });
    Ok(activated)
}
/// Computes the matrix sign function of a square matrix.
///
/// Two shortcuts are tried before the general algorithm:
///
/// * if no entry has magnitude above `F::epsilon()`, the zero matrix is returned;
/// * if no off-diagonal entry has magnitude above `F::epsilon()`, the result is
///   the diagonal matrix holding the scalar sign (`1`, `-1` or `0`) of each
///   diagonal entry.
///
/// Both thresholds are absolute, not relative to the scale of `a`.
///
/// Otherwise the Newton iteration `X <- (X + X^-1) / 2` is run, starting from
/// `X = a`, for at most 100 steps. It stops early once the Frobenius norm of the
/// change, relative to the Frobenius norm of the new iterate, drops below the
/// tolerance.
///
/// # Arguments
///
/// * `a` - Input matrix. It is first passed to `validate_decomposition`
///   (presumably a squareness / finiteness check — confirm against that helper).
///
/// # Returns
///
/// The last iterate. NOTE: when the iteration does not reach the tolerance
/// within 100 steps, the current iterate is still returned as `Ok`.
///
/// # Errors
///
/// * Any error reported by `validate_decomposition`.
/// * Any error reported by `solve_multiple` while inverting an iterate.
/// * `LinalgError::ComputationError` if `0.5` cannot be represented in `F`.
#[allow(dead_code)]
pub fn signm<F>(a: &ArrayView2<F>) -> LinalgResult<Array2<F>>
where
    F: Float + NumAssign + Sum + One + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
{
    use crate::solve::solve_multiple;

    validate_decomposition(a, "Matrix sign computation", true)?;
    let n = a.nrows();

    // An entry counts as non-negligible when its magnitude exceeds machine epsilon.
    let is_significant = |i: usize, j: usize| a[[i, j]].abs() > F::epsilon();

    // Shortcut 1: (numerically) zero matrix.
    let is_zero = !(0..n).any(|i| (0..n).any(|j| is_significant(i, j)));
    if is_zero {
        return Ok(Array2::<F>::zeros((n, n)));
    }

    // Shortcut 2: (numerically) diagonal matrix -> element-wise sign of the diagonal.
    let is_diagonal = !(0..n).any(|i| (0..n).any(|j| i != j && is_significant(i, j)));
    if is_diagonal {
        let mut result = Array2::<F>::zeros((n, n));
        for i in 0..n {
            let val = a[[i, i]];
            result[[i, i]] = if val > F::zero() {
                F::one()
            } else if val < F::zero() {
                -F::one()
            } else {
                F::zero()
            };
        }
        return Ok(result);
    }

    let half = F::from(0.5)
        .ok_or_else(|| LinalgError::ComputationError("Failed to convert 0.5".to_string()))?;
    let max_iter = 100usize;

    // A relative change below machine epsilon is not reliably observable, so the
    // requested 1e-12 is floored at `F::epsilon()`. Without the floor, low-precision
    // types such as f32 could only stop once the iterate became (almost) exactly
    // stationary. For f64 the tolerance stays 1e-12.
    let requested_tol = F::from(1e-12).unwrap_or_else(F::epsilon);
    let tol = if requested_tol < F::epsilon() {
        F::epsilon()
    } else {
        requested_tol
    };

    // The identity right-hand side is loop-invariant; build it once.
    let eye = Array2::<F>::eye(n);
    let mut x = a.to_owned();
    for _ in 0..max_iter {
        // X^-1 is obtained by solving X * Y = I.
        let x_inv = solve_multiple(&x.view(), &eye.view(), None)?;
        let x_new = (&x + &x_inv) * half;

        // Frobenius norm of the update and of the new iterate.
        let diff_norm: F = x_new
            .iter()
            .zip(x.iter())
            .map(|(&next, &prev)| {
                let d = next - prev;
                d * d
            })
            .fold(F::zero(), |acc, v| acc + v)
            .sqrt();
        let x_norm: F = x_new
            .iter()
            .map(|&v| v * v)
            .fold(F::zero(), |acc, v| acc + v)
            .sqrt();
        x = x_new;

        // Fall back to the absolute change when the iterate itself is negligible,
        // to avoid dividing by a (near-)zero norm.
        let rel_err = if x_norm > F::epsilon() {
            diff_norm / x_norm
        } else {
            diff_norm
        };
        if rel_err < tol {
            break;
        }
    }
    Ok(x)
}