use crate::op::{ComputeContext, GradientContext, Op, OpError};
use crate::tensor::Tensor;
use crate::tensor_ops;
use crate::Float;
use scirs2_core::ndarray::{Array2, ArrayView2, Ix2};
/// Graph operator computing the matrix 1-norm (maximum absolute column sum).
pub struct Matrix1NormOp;
impl<F: Float> Op<F> for Matrix1NormOp {
    fn name(&self) -> &'static str {
        "Matrix1Norm"
    }

    /// Forward pass: validates that the input is 2-D, then emits the 1-norm
    /// (max absolute column sum) as a 0-d output tensor.
    fn compute(&self, ctx: &mut ComputeContext<F>) -> Result<(), OpError> {
        let input = ctx.input(0);
        if input.shape().len() != 2 {
            return Err(OpError::IncompatibleShape(
                "Matrix 1-norm requires 2D matrix".into(),
            ));
        }
        let matrix = input
            .view()
            .into_dimensionality::<Ix2>()
            .map_err(|_| OpError::IncompatibleShape("Failed to convert to 2D array".into()))?;
        let norm = compute_matrix_1_norm(&matrix);
        ctx.append_output(scirs2_core::ndarray::arr0(norm).into_dyn());
        Ok(())
    }

    /// Backward pass: scales the 1-norm subgradient by the incoming scalar
    /// gradient. Appends `None` if either tensor fails to evaluate or the
    /// input cannot be viewed as 2-D.
    fn grad(&self, ctx: &mut GradientContext<F>) {
        let grad_output = ctx.output_grad();
        let input = ctx.input(0);
        let g = ctx.graph();
        // Evaluate the input first; bail out with a missing gradient on failure.
        let input_array = if let Ok(arr) = input.eval(g) {
            arr
        } else {
            ctx.append_input_grad(0, None);
            return;
        };
        let grad_output_array = if let Ok(arr) = grad_output.eval(g) {
            arr
        } else {
            ctx.append_input_grad(0, None);
            return;
        };
        // The upstream gradient is a 0-d array; extract its scalar value.
        let grad_scalar = grad_output_array[[]];
        match input_array.view().into_dimensionality::<Ix2>() {
            Ok(matrix) => {
                let grad_matrix = compute_matrix_1_norm_gradient(&matrix, grad_scalar);
                let grad_tensor = tensor_ops::convert_to_tensor(grad_matrix.into_dyn(), g);
                ctx.append_input_grad(0, Some(grad_tensor));
            }
            Err(_) => ctx.append_input_grad(0, None),
        }
    }
}
/// Graph operator computing the matrix infinity-norm (maximum absolute row sum).
pub struct MatrixInfNormOp;
impl<F: Float> Op<F> for MatrixInfNormOp {
    fn name(&self) -> &'static str {
        "MatrixInfNorm"
    }

    /// Forward pass: validates that the input is 2-D, then emits the
    /// infinity-norm (max absolute row sum) as a 0-d output tensor.
    fn compute(&self, ctx: &mut ComputeContext<F>) -> Result<(), OpError> {
        let input = ctx.input(0);
        if input.shape().len() != 2 {
            return Err(OpError::IncompatibleShape(
                "Matrix infinity-norm requires 2D matrix".into(),
            ));
        }
        let matrix = input
            .view()
            .into_dimensionality::<Ix2>()
            .map_err(|_| OpError::IncompatibleShape("Failed to convert to 2D array".into()))?;
        let norm = compute_matrix_inf_norm(&matrix);
        ctx.append_output(scirs2_core::ndarray::arr0(norm).into_dyn());
        Ok(())
    }

    /// Backward pass: scales the infinity-norm subgradient by the incoming
    /// scalar gradient. Appends `None` on any evaluation or shape failure.
    fn grad(&self, ctx: &mut GradientContext<F>) {
        let grad_output = ctx.output_grad();
        let input = ctx.input(0);
        let g = ctx.graph();
        // Evaluate the input first; a failure means no gradient can flow.
        let input_array = if let Ok(arr) = input.eval(g) {
            arr
        } else {
            ctx.append_input_grad(0, None);
            return;
        };
        let grad_output_array = if let Ok(arr) = grad_output.eval(g) {
            arr
        } else {
            ctx.append_input_grad(0, None);
            return;
        };
        // 0-d upstream gradient -> plain scalar.
        let grad_scalar = grad_output_array[[]];
        match input_array.view().into_dimensionality::<Ix2>() {
            Ok(matrix) => {
                let grad_matrix = compute_matrix_inf_norm_gradient(&matrix, grad_scalar);
                let grad_tensor = tensor_ops::convert_to_tensor(grad_matrix.into_dyn(), g);
                ctx.append_input_grad(0, Some(grad_tensor));
            }
            Err(_) => ctx.append_input_grad(0, None),
        }
    }
}
/// Graph operator computing the matrix 2-norm (spectral norm: largest
/// singular value, estimated by power iteration).
pub struct Matrix2NormOp;
impl<F: Float + scirs2_core::ndarray::ScalarOperand> Op<F> for Matrix2NormOp {
    fn name(&self) -> &'static str {
        "Matrix2Norm"
    }

    /// Forward pass: validates that the input is 2-D, then emits the spectral
    /// norm (power-iteration estimate) as a 0-d output tensor.
    fn compute(&self, ctx: &mut ComputeContext<F>) -> Result<(), OpError> {
        let input = ctx.input(0);
        if input.shape().len() != 2 {
            return Err(OpError::IncompatibleShape(
                "Matrix 2-norm requires 2D matrix".into(),
            ));
        }
        let matrix = input
            .view()
            .into_dimensionality::<Ix2>()
            .map_err(|_| OpError::IncompatibleShape("Failed to convert to 2D array".into()))?;
        let norm = compute_matrix_2_norm(&matrix);
        ctx.append_output(scirs2_core::ndarray::arr0(norm).into_dyn());
        Ok(())
    }

    /// Backward pass: the spectral-norm gradient (outer product of the
    /// dominant singular vectors) scaled by the incoming scalar gradient.
    /// Appends `None` on any evaluation or shape failure.
    fn grad(&self, ctx: &mut GradientContext<F>) {
        let grad_output = ctx.output_grad();
        let input = ctx.input(0);
        let g = ctx.graph();
        // Evaluate the input first; bail out with a missing gradient on failure.
        let input_array = if let Ok(arr) = input.eval(g) {
            arr
        } else {
            ctx.append_input_grad(0, None);
            return;
        };
        let grad_output_array = if let Ok(arr) = grad_output.eval(g) {
            arr
        } else {
            ctx.append_input_grad(0, None);
            return;
        };
        // 0-d upstream gradient -> plain scalar.
        let grad_scalar = grad_output_array[[]];
        match input_array.view().into_dimensionality::<Ix2>() {
            Ok(matrix) => {
                let grad_matrix = compute_matrix_2_norm_gradient(&matrix, grad_scalar);
                let grad_tensor = tensor_ops::convert_to_tensor(grad_matrix.into_dyn(), g);
                ctx.append_input_grad(0, Some(grad_tensor));
            }
            Err(_) => ctx.append_input_grad(0, None),
        }
    }
}
#[allow(dead_code)]
/// Matrix 1-norm: the maximum over columns of the sum of absolute entries.
/// Returns zero for an empty matrix.
fn compute_matrix_1_norm<F: Float>(matrix: &ArrayView2<F>) -> F {
    let (rows, cols) = matrix.dim();
    (0..cols)
        .map(|j| (0..rows).fold(F::zero(), |acc, i| acc + matrix[[i, j]].abs()))
        .fold(F::zero(), |best, col_sum| {
            if col_sum > best {
                col_sum
            } else {
                best
            }
        })
}
#[allow(dead_code)]
/// Subgradient of the matrix 1-norm (maximum absolute column sum) with
/// respect to the matrix, scaled by the upstream scalar gradient.
///
/// Only the first column attaining the maximum sum receives gradient (strict
/// `>` keeps the earliest maximiser, matching `compute_matrix_1_norm`); each
/// of its entries gets `sign(a_ij) * grad_scalar`, with 0 chosen at exact
/// zeros (a valid subgradient).
///
/// Fixes vs. the previous version: the parameter is renamed from
/// `gradscalar` to `grad_scalar` for consistency with the sibling gradient
/// helpers, and an empty matrix (zero columns) no longer panics by indexing
/// a nonexistent column.
fn compute_matrix_1_norm_gradient<F: Float>(matrix: &ArrayView2<F>, grad_scalar: F) -> Array2<F> {
    let (m, n) = matrix.dim();
    let mut grad_matrix = Array2::zeros((m, n));
    // Nothing to propagate through an empty matrix.
    if n == 0 {
        return grad_matrix;
    }
    // Locate the first column with the largest absolute sum.
    let mut max_col = 0;
    let mut max_col_sum = F::zero();
    for j in 0..n {
        let col_sum = (0..m).fold(F::zero(), |acc, i| acc + matrix[[i, j]].abs());
        if col_sum > max_col_sum {
            max_col_sum = col_sum;
            max_col = j;
        }
    }
    // d|x|/dx = sign(x); pick 0 at x == 0.
    for i in 0..m {
        let elem = matrix[[i, max_col]];
        grad_matrix[[i, max_col]] = if elem > F::zero() {
            grad_scalar
        } else if elem < F::zero() {
            -grad_scalar
        } else {
            F::zero()
        };
    }
    grad_matrix
}
#[allow(dead_code)]
/// Matrix infinity-norm: the maximum over rows of the sum of absolute
/// entries. Returns zero for an empty matrix.
fn compute_matrix_inf_norm<F: Float>(matrix: &ArrayView2<F>) -> F {
    let (rows, cols) = matrix.dim();
    (0..rows)
        .map(|i| (0..cols).fold(F::zero(), |acc, j| acc + matrix[[i, j]].abs()))
        .fold(F::zero(), |best, row_sum| {
            if row_sum > best {
                row_sum
            } else {
                best
            }
        })
}
#[allow(dead_code)]
/// Subgradient of the matrix infinity-norm (maximum absolute row sum) with
/// respect to the matrix, scaled by the upstream scalar gradient.
///
/// Only the first row attaining the maximum sum receives gradient (strict
/// `>` keeps the earliest maximiser, matching `compute_matrix_inf_norm`);
/// each of its entries gets `sign(a_ij) * grad_scalar`, with 0 chosen at
/// exact zeros (a valid subgradient).
///
/// Fix vs. the previous version: an empty matrix (zero rows) no longer
/// panics by indexing a nonexistent row.
fn compute_matrix_inf_norm_gradient<F: Float>(matrix: &ArrayView2<F>, grad_scalar: F) -> Array2<F> {
    let (m, n) = matrix.dim();
    let mut grad_matrix = Array2::zeros((m, n));
    // Nothing to propagate through an empty matrix.
    if m == 0 {
        return grad_matrix;
    }
    // Locate the first row with the largest absolute sum.
    let mut max_row = 0;
    let mut max_row_sum = F::zero();
    for i in 0..m {
        let row_sum = (0..n).fold(F::zero(), |acc, j| acc + matrix[[i, j]].abs());
        if row_sum > max_row_sum {
            max_row_sum = row_sum;
            max_row = i;
        }
    }
    // d|x|/dx = sign(x); pick 0 at x == 0.
    for j in 0..n {
        let elem = matrix[[max_row, j]];
        grad_matrix[[max_row, j]] = if elem > F::zero() {
            grad_scalar
        } else if elem < F::zero() {
            -grad_scalar
        } else {
            F::zero()
        };
    }
    grad_matrix
}
#[allow(dead_code)]
/// Spectral norm (largest singular value) of `matrix`, estimated with power
/// iteration: at most 50 iterations, tolerance 1e-8.
fn compute_matrix_2_norm<F: Float + scirs2_core::ndarray::ScalarOperand>(
    matrix: &ArrayView2<F>,
) -> F {
    let tol = F::from(1e-8).expect("Failed to convert constant to float");
    power_iteration_2norm(matrix, 50, tol).1
}
#[allow(dead_code)]
/// Estimates the largest singular value of `matrix` by power iteration on
/// `A^T A`, returning `(u, sigma)` where `sigma` is the spectral-norm
/// estimate and `u` is the corresponding *left* singular vector (length `m`,
/// the row count).
///
/// Fixes vs. the previous version:
/// - The iterate was seeded with length `m` while `matrix.dot(...)` requires
///   length `n`, so the very first multiply panicked for any non-square
///   matrix. The iteration now runs on the right singular vector `v`
///   (length `n`).
/// - The converged iterate of power iteration on `A^T A` is the *right*
///   singular vector, but callers treat the returned vector as the *left*
///   one (e.g. `compute_matrix_2_norm_gradient` forms `u v^T`). We now
///   recover and return `u = A v / sigma` explicitly.
fn power_iteration_2norm<F: Float + scirs2_core::ndarray::ScalarOperand>(
    matrix: &ArrayView2<F>,
    max_iter: usize,
    tol: F,
) -> (scirs2_core::ndarray::Array1<F>, F) {
    let (m, n) = matrix.dim();
    // Degenerate shapes have norm 0 and no meaningful singular vector.
    if m == 0 || n == 0 {
        return (scirs2_core::ndarray::Array1::zeros(m), F::zero());
    }
    // Euclidean length of a vector.
    let l2 = |a: &scirs2_core::ndarray::Array1<F>| {
        a.iter().fold(F::zero(), |acc, &x| acc + x * x).sqrt()
    };
    // Deterministic, non-degenerate start vector (same scheme as before),
    // normalized to unit length.
    let mut v = scirs2_core::ndarray::Array1::<F>::zeros(n);
    v[0] = F::one();
    for i in 1..n {
        v[i] = F::from(0.01).expect("Failed to convert constant to float")
            * F::from(i as f64).expect("Failed to convert to float");
    }
    let norm = l2(&v);
    if norm > F::epsilon() {
        v.mapv_inplace(|x| x / norm);
    }
    let mut av = matrix.dot(&v);
    let mut sigma = l2(&av);
    let mut prev_sigma = sigma;
    for _iter in 0..max_iter {
        // One power-iteration step on A^T A: v <- A^T (A v), renormalized.
        let atav = matrix.t().dot(&av);
        let atav_norm = l2(&atav);
        if atav_norm <= F::epsilon() {
            // A v lies in the null space of A^T: sigma estimate is final.
            break;
        }
        v = atav.mapv(|x| x / atav_norm);
        av = matrix.dot(&v);
        sigma = l2(&av);
        if (sigma - prev_sigma).abs() < tol {
            break;
        }
        prev_sigma = sigma;
    }
    // Recover the left singular vector u = A v / sigma (length m).
    let u = if sigma > F::epsilon() {
        av.mapv(|x| x / sigma)
    } else {
        scirs2_core::ndarray::Array1::zeros(m)
    };
    (u, sigma)
}
#[allow(dead_code)]
/// Gradient of the spectral norm: d||A||_2 / dA = u v^T (outer product of
/// the dominant singular vectors), scaled by the upstream scalar gradient.
/// Uses a cheaper power-iteration estimate (20 iterations, tolerance 1e-6)
/// than the forward pass.
fn compute_matrix_2_norm_gradient<F: Float + scirs2_core::ndarray::ScalarOperand>(
    matrix: &ArrayView2<F>,
    grad_scalar: F,
) -> Array2<F> {
    let (m, n) = matrix.dim();
    let tol = F::from(1e-6).expect("Failed to convert constant to float");
    let (u, sigma) = power_iteration_2norm(matrix, 20, tol);
    // Right singular vector v = A^T u / sigma; zero vector when sigma ~ 0
    // (no well-defined direction).
    let v = if sigma > F::epsilon() {
        matrix.t().dot(&u) / sigma
    } else {
        scirs2_core::ndarray::Array1::zeros(n)
    };
    // Scaled outer product u v^T.
    Array2::from_shape_fn((m, n), |(i, j)| u[i] * v[j] * grad_scalar)
}
#[allow(dead_code)]
/// Builds a graph node computing the matrix 1-norm (maximum absolute column
/// sum) of `matrix`.
pub fn norm1<'g, F: Float>(matrix: &Tensor<'g, F>) -> Tensor<'g, F> {
    Tensor::builder(matrix.graph())
        .append_input(matrix, false)
        .build(Matrix1NormOp)
}
#[allow(dead_code)]
/// Builds a graph node computing the matrix 2-norm (spectral norm) of
/// `matrix`.
pub fn norm2<'g, F: Float + scirs2_core::ndarray::ScalarOperand>(
    matrix: &Tensor<'g, F>,
) -> Tensor<'g, F> {
    Tensor::builder(matrix.graph())
        .append_input(matrix, false)
        .build(Matrix2NormOp)
}
#[allow(dead_code)]
/// Builds a graph node computing the matrix infinity-norm (maximum absolute
/// row sum) of `matrix`.
pub fn norminf<'g, F: Float>(matrix: &Tensor<'g, F>) -> Tensor<'g, F> {
    Tensor::builder(matrix.graph())
        .append_input(matrix, false)
        .build(MatrixInfNormOp)
}
#[allow(dead_code)]
/// Frobenius norm of `matrix`; delegates to the dedicated implementation in
/// `norm_ops`.
pub fn normfro<'g, F: Float>(matrix: &Tensor<'g, F>) -> Tensor<'g, F> {
    crate::tensor_ops::norm_ops::frobenius_norm(matrix)
}