use crate::decomposition::svd;
use crate::error::{LinalgError, LinalgResult};
use scirs2_core::ndarray::{ArrayView2, ScalarOperand};
use scirs2_core::numeric::{Float, NumAssign};
use std::iter::Sum;
/// Returns the singular values of `a` ordered from largest to smallest.
///
/// The values come from the crate's [`svd`]; only the singular-value vector is
/// kept. Incomparable pairs (NaN) are treated as equal by the sort.
///
/// # Errors
/// Returns [`LinalgError::InvalidInputError`] when `a` has zero rows or zero
/// columns, and propagates any error reported by [`svd`].
fn sorted_singular_values<F>(a: &ArrayView2<F>) -> LinalgResult<Vec<F>>
where
    F: Float + NumAssign + Sum + ScalarOperand + Send + Sync + 'static,
{
    let has_no_entries = a.nrows() == 0 || a.ncols() == 0;
    if has_no_entries {
        return Err(LinalgError::InvalidInputError(
            "matrix_norms: matrix must be non-empty".to_string(),
        ));
    }

    let (_, sigma, _) = svd(a, false, None)?;
    let mut descending: Vec<F> = sigma.iter().copied().collect();
    // Reverse comparison (rhs vs lhs) gives descending order.
    descending.sort_by(|lhs, rhs| rhs.partial_cmp(lhs).unwrap_or(std::cmp::Ordering::Equal));
    Ok(descending)
}
/// Nuclear (trace) norm of `a`: the sum of all of its singular values.
///
/// # Errors
/// Fails when `a` is empty or when the underlying SVD fails.
pub fn nuclear_norm<F>(a: &ArrayView2<F>) -> LinalgResult<F>
where
    F: Float + NumAssign + Sum + ScalarOperand + Send + Sync + 'static,
{
    let mut total = F::zero();
    for sigma in sorted_singular_values(a)? {
        total += sigma;
    }
    Ok(total)
}
/// Operator (spectral) 2-norm of `a`: its largest singular value.
///
/// Yields zero if no singular values are returned.
///
/// # Errors
/// Fails when `a` is empty or when the underlying SVD fails.
pub fn operator_norm<F>(a: &ArrayView2<F>) -> LinalgResult<F>
where
    F: Float + NumAssign + Sum + ScalarOperand + Send + Sync + 'static,
{
    sorted_singular_values(a).map(|sigmas| match sigmas.first() {
        Some(&top) => top,
        None => F::zero(),
    })
}
/// Entry-wise mixed `(p, q)` norm of `a`.
///
/// First takes the `p`-norm of every column, then the `q`-norm of the vector of
/// column norms. Either exponent may be `F::infinity()`, meaning "maximum
/// absolute value". `(2, 2)` is the Frobenius norm; `(1, 1)` is the sum of
/// absolute entries.
///
/// Each vector norm is computed with magnitudes scaled by the largest one, so
/// large or tiny entries do not spuriously overflow to `inf` or underflow to 0.
/// A NaN entry makes the result NaN.
///
/// # Errors
/// - [`LinalgError::ValueError`] if `p` or `q` is NaN or less than 1.
/// - [`LinalgError::InvalidInputError`] if `a` has no rows or no columns.
pub fn matrix_pq_norm<F>(a: &ArrayView2<F>, p: F, q: F) -> LinalgResult<F>
where
    F: Float + NumAssign + Sum + ScalarOperand + Send + Sync + 'static,
{
    let one = F::one();
    // Every comparison with NaN is false, so `p < one` alone would let NaN through.
    if p.is_nan() || p < one {
        let p_f64 = p.to_f64().unwrap_or(0.0);
        return Err(LinalgError::ValueError(format!(
            "matrix_pq_norm: p must be >= 1, got p = {p_f64}"
        )));
    }
    if q.is_nan() || q < one {
        let q_f64 = q.to_f64().unwrap_or(0.0);
        return Err(LinalgError::ValueError(format!(
            "matrix_pq_norm: q must be >= 1, got q = {q_f64}"
        )));
    }
    if a.nrows() == 0 || a.ncols() == 0 {
        return Err(LinalgError::InvalidInputError(
            "matrix_pq_norm: matrix must be non-empty".to_string(),
        ));
    }

    // One reusable buffer: columns of a view are not necessarily contiguous.
    let mut column: Vec<F> = Vec::with_capacity(a.nrows());
    let mut col_norms: Vec<F> = Vec::with_capacity(a.ncols());
    for j in 0..a.ncols() {
        column.clear();
        column.extend(a.column(j).iter().copied());
        col_norms.push(scaled_abs_p_norm(&column, p));
    }
    Ok(scaled_abs_p_norm(&col_norms, q))
}

/// `p`-norm of `|x|` over `values`, for `p >= 1` or `p = +inf` (caller validates).
///
/// Magnitudes are divided by the largest one before being raised to the `p`-th
/// power, so the intermediate sum cannot overflow or underflow unless the norm
/// itself does. Returns NaN if any entry is NaN, for finite and infinite `p` alike.
fn scaled_abs_p_norm<F: Float>(values: &[F], p: F) -> F {
    let mut largest = F::zero();
    for &v in values {
        let mag = v.abs();
        if mag.is_nan() {
            return mag;
        }
        if mag > largest {
            largest = mag;
        }
    }
    // For p = inf the maximum is the answer. For an all-zero or infinite maximum,
    // scaling would produce 0/0 or inf/inf, and the maximum already equals the norm.
    if p.is_infinite() || largest == F::zero() || largest.is_infinite() {
        return largest;
    }
    let sum = values
        .iter()
        .fold(F::zero(), |acc, &v| acc + (v.abs() / largest).powf(p));
    largest * sum.powf(F::one() / p)
}
/// Schatten `p`-norm of `a`: the vector `p`-norm of its singular values.
///
/// `p = 1` is the nuclear norm, `p = 2` the Frobenius norm, and
/// `p = F::infinity()` the operator (spectral) norm. The sum of powers is
/// computed on singular values scaled by the largest one, so it does not
/// spuriously overflow for large singular values or large `p`.
///
/// # Errors
/// - [`LinalgError::ValueError`] if `p` is NaN or less than 1.
/// - Propagates the errors of the singular-value computation (empty `a`, SVD failure).
pub fn schatten_norm<F>(a: &ArrayView2<F>, p: F) -> LinalgResult<F>
where
    F: Float + NumAssign + Sum + ScalarOperand + Send + Sync + 'static,
{
    let one = F::one();
    // Every comparison with NaN is false, so `p < one` alone would let NaN through.
    if p.is_nan() || p < one {
        let p_f64 = p.to_f64().unwrap_or(0.0);
        return Err(LinalgError::ValueError(format!(
            "schatten_norm: p must be >= 1, got p = {p_f64}"
        )));
    }
    let svs = sorted_singular_values(a)?;
    // Values are sorted in descending order, so the first one is the largest.
    let largest = svs.first().copied().unwrap_or(F::zero());
    // For p = inf this is the answer. For a zero or infinite largest value,
    // scaling would produce 0/0 or inf/inf, and `largest` already equals the norm.
    if p.is_infinite() || largest == F::zero() || largest.is_infinite() {
        return Ok(largest);
    }
    let sum_p = svs
        .iter()
        .fold(F::zero(), |acc, &s| acc + (s / largest).powf(p));
    Ok(largest * sum_p.powf(one / p))
}
/// Ky Fan `k`-norm of `a`: the sum of its `k` largest singular values.
///
/// `k = 1` gives the operator norm; `k` equal to the number of singular values
/// gives the nuclear norm.
///
/// # Errors
/// - [`LinalgError::ValueError`] if `k` is 0 (checked before any SVD work) or
///   exceeds the number of singular values.
/// - Propagates the errors of the singular-value computation (empty `a`, SVD failure).
pub fn ky_fan_norm<F>(a: &ArrayView2<F>, k: usize) -> LinalgResult<F>
where
    F: Float + NumAssign + Sum + ScalarOperand + Send + Sync + 'static,
{
    if k == 0 {
        return Err(LinalgError::ValueError(
            "ky_fan_norm: k must be >= 1".to_string(),
        ));
    }
    let sigmas = sorted_singular_values(a)?;
    let available = sigmas.len();
    // `get(..k)` is `None` exactly when `k` exceeds the number of singular values.
    match sigmas.get(..k) {
        Some(leading) => Ok(leading.iter().fold(F::zero(), |acc, &s| acc + s)),
        None => Err(LinalgError::ValueError(format!(
            "ky_fan_norm: k={k} exceeds number of singular values {}",
            available
        ))),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array2};

    /// Asserts `actual` is within `1e-10` of `expected`, reporting both on failure.
    fn assert_close(actual: f64, expected: f64, what: &str) {
        assert!(
            (actual - expected).abs() < 1e-10,
            "{what}: expected {expected}, got {actual}"
        );
    }

    #[test]
    fn test_nuclear_norm_diagonal() {
        let a = array![[3.0_f64, 0.0], [0.0, 4.0]];
        let nn = nuclear_norm(&a.view()).expect("nuclear_norm");
        assert_close(nn, 7.0, "nuclear norm of diag(3, 4)");
    }

    #[test]
    fn test_nuclear_norm_identity() {
        let a = Array2::<f64>::eye(4);
        let nn = nuclear_norm(&a.view()).expect("nuclear_norm identity");
        assert_close(nn, 4.0, "nuclear norm of I_4");
    }

    #[test]
    fn test_nuclear_norm_rank1() {
        let a = array![[1.0_f64, 2.0], [2.0, 4.0]];
        let nn = nuclear_norm(&a.view()).expect("nuclear_norm rank1");
        assert_close(nn, 5.0, "nuclear norm of rank-1 matrix");
    }

    #[test]
    fn test_operator_norm_diagonal() {
        let a = array![[3.0_f64, 0.0], [0.0, 4.0]];
        let on = operator_norm(&a.view()).expect("operator_norm");
        assert_close(on, 4.0, "operator norm of diag(3, 4)");
    }

    #[test]
    fn test_operator_norm_identity() {
        let a = Array2::<f64>::eye(3);
        let on = operator_norm(&a.view()).expect("operator_norm identity");
        assert_close(on, 1.0, "operator norm of I_3");
    }

    #[test]
    fn test_pq_norm_22_is_frobenius() {
        let a = array![[1.0_f64, 2.0], [3.0, 4.0]];
        let pq = matrix_pq_norm(&a.view(), 2.0, 2.0).expect("pq 2,2");
        let fro = (1.0_f64 + 4.0 + 9.0 + 16.0_f64).sqrt();
        assert_close(pq, fro, "(2,2) norm vs Frobenius");
    }

    #[test]
    fn test_pq_norm_11_is_l1_entry() {
        let a = array![[1.0_f64, 2.0], [3.0, 4.0]];
        let pq = matrix_pq_norm(&a.view(), 1.0, 1.0).expect("pq 1,1");
        assert_close(pq, 10.0, "(1,1) norm vs sum of |entries|");
    }

    #[test]
    fn test_pq_norm_inf_inf_is_max_abs_entry() {
        let a = array![[1.0_f64, -2.0], [3.0, -4.0]];
        let pq = matrix_pq_norm(&a.view(), f64::INFINITY, f64::INFINITY).expect("pq inf,inf");
        assert_close(pq, 4.0, "(inf,inf) norm vs max |entry|");
    }

    #[test]
    fn test_pq_norm_1_inf_is_max_column_sum() {
        let a = array![[1.0_f64, 2.0], [3.0, 4.0]];
        let pq = matrix_pq_norm(&a.view(), 1.0, f64::INFINITY).expect("pq 1,inf");
        assert_close(pq, 6.0, "(1,inf) norm vs max absolute column sum");
    }

    #[test]
    fn test_pq_norm_invalid_p() {
        let a = array![[1.0_f64, 0.0], [0.0, 1.0]];
        assert!(matrix_pq_norm(&a.view(), 0.5, 2.0).is_err());
    }

    #[test]
    fn test_pq_norm_invalid_q() {
        let a = array![[1.0_f64, 0.0], [0.0, 1.0]];
        assert!(matrix_pq_norm(&a.view(), 2.0, 0.5).is_err());
    }

    #[test]
    fn test_empty_matrix_is_error() {
        let a = Array2::<f64>::zeros((0, 3));
        assert!(nuclear_norm(&a.view()).is_err());
        assert!(matrix_pq_norm(&a.view(), 2.0, 2.0).is_err());
    }

    #[test]
    fn test_schatten_1_equals_nuclear() {
        let a = array![[3.0_f64, 0.0], [0.0, 4.0]];
        let s1 = schatten_norm(&a.view(), 1.0).expect("schatten 1");
        let nn = nuclear_norm(&a.view()).expect("nuclear_norm");
        assert_close(s1, nn, "Schatten-1 vs nuclear");
    }

    #[test]
    fn test_schatten_2_equals_frobenius() {
        let a = array![[1.0_f64, 2.0], [3.0, 4.0]];
        let s2 = schatten_norm(&a.view(), 2.0).expect("schatten 2");
        let fro: f64 = a.iter().map(|&v| v * v).sum::<f64>().sqrt();
        assert_close(s2, fro, "Schatten-2 vs Frobenius");
    }

    #[test]
    fn test_schatten_inf_equals_operator() {
        let a = array![[2.0_f64, 1.0], [1.0, 3.0]];
        let s_inf = schatten_norm(&a.view(), f64::INFINITY).expect("schatten inf");
        let on = operator_norm(&a.view()).expect("operator_norm");
        assert_close(s_inf, on, "Schatten-inf vs operator");
    }

    #[test]
    fn test_schatten_invalid_p() {
        let a = array![[1.0_f64, 0.0], [0.0, 1.0]];
        assert!(schatten_norm(&a.view(), 0.5).is_err());
    }

    #[test]
    fn test_schatten_monotone_in_p() {
        let a = array![[5.0_f64, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 1.0]];
        let s1 = schatten_norm(&a.view(), 1.0).expect("s1");
        let s2 = schatten_norm(&a.view(), 2.0).expect("s2");
        let s4 = schatten_norm(&a.view(), 4.0).expect("s4");
        assert!(s1 >= s2 - 1e-10 && s2 >= s4 - 1e-10, "monotone: {s1} >= {s2} >= {s4}");
    }

    #[test]
    fn test_ky_fan_1_is_spectral() {
        let a = array![[5.0_f64, 0.0], [0.0, 3.0]];
        let kf1 = ky_fan_norm(&a.view(), 1).expect("kf1");
        assert_close(kf1, 5.0, "Ky Fan 1 vs largest singular value");
    }

    #[test]
    fn test_ky_fan_max_is_nuclear() {
        let a = array![[3.0_f64, 0.0], [0.0, 4.0]];
        let kf2 = ky_fan_norm(&a.view(), 2).expect("kf2");
        let nn = nuclear_norm(&a.view()).expect("nuclear_norm");
        assert_close(kf2, nn, "Ky Fan n vs nuclear");
    }

    #[test]
    fn test_ky_fan_monotone() {
        let a = array![[5.0_f64, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 1.0]];
        let kf1 = ky_fan_norm(&a.view(), 1).expect("kf1");
        let kf2 = ky_fan_norm(&a.view(), 2).expect("kf2");
        let kf3 = ky_fan_norm(&a.view(), 3).expect("kf3");
        assert!(
            kf1 <= kf2 + 1e-10 && kf2 <= kf3 + 1e-10,
            "monotone: {kf1} <= {kf2} <= {kf3}"
        );
    }

    #[test]
    fn test_ky_fan_k0_error() {
        let a = Array2::<f64>::eye(2);
        assert!(ky_fan_norm(&a.view(), 0).is_err());
    }

    #[test]
    fn test_ky_fan_k_too_large_error() {
        let a = array![[1.0_f64, 0.0], [0.0, 1.0]];
        assert!(ky_fan_norm(&a.view(), 3).is_err());
    }

    #[test]
    fn test_nuclear_ge_operator() {
        let a = array![[2.0_f64, 1.0], [1.0, 3.0]];
        let nn = nuclear_norm(&a.view()).expect("nuclear_norm");
        let on = operator_norm(&a.view()).expect("operator_norm");
        assert!(nn >= on - 1e-10, "nuclear {nn} >= operator {on}");
    }
}