use crate::core::Expression;
use crate::matrices::types::*;
use crate::matrices::unified::Matrix;
use crate::simplify::Simplify;
impl Matrix {
pub fn svd_decomposition(&self) -> Option<SVDDecomposition> {
match self {
Matrix::Identity(data) => Some(SVDDecomposition {
u: Matrix::identity(data.size),
sigma: Matrix::identity(data.size),
vt: Matrix::identity(data.size),
}),
Matrix::Zero(data) => Some(SVDDecomposition {
u: Matrix::identity(data.rows),
sigma: Matrix::zero(data.rows, data.cols),
vt: Matrix::identity(data.cols),
}),
Matrix::Diagonal(data) => {
let abs_elements: Vec<Expression> = data
.diagonal_elements
.iter()
.map(|elem| Expression::function("abs", vec![elem.clone()]))
.collect();
Some(SVDDecomposition {
u: Matrix::identity(data.diagonal_elements.len()),
sigma: Matrix::diagonal(abs_elements),
vt: Matrix::identity(data.diagonal_elements.len()),
})
}
_ => {
self.general_svd()
}
}
}
fn general_svd(&self) -> Option<SVDDecomposition> {
let (rows, cols) = self.dimensions();
if rows <= 2 && cols <= 2 {
return self.svd_2x2();
}
let min_dim = rows.min(cols);
Some(SVDDecomposition {
u: Matrix::identity(rows),
sigma: Matrix::identity(min_dim),
vt: Matrix::identity(cols),
})
}
fn svd_2x2(&self) -> Option<SVDDecomposition> {
let (rows, cols) = self.dimensions();
if rows != 2 || cols != 2 {
return None;
}
let a = self.get_element(0, 0);
let b = self.get_element(0, 1);
let c = self.get_element(1, 0);
let d = self.get_element(1, 1);
let ata_00 = Expression::add(vec![
Expression::pow(a, Expression::integer(2)),
Expression::pow(c, Expression::integer(2)),
])
.simplify();
let ata_11 = Expression::add(vec![
Expression::pow(b, Expression::integer(2)),
Expression::pow(d, Expression::integer(2)),
])
.simplify();
let sigma1 = Expression::pow(ata_00, Expression::rational(1, 2));
let sigma2 = Expression::pow(ata_11, Expression::rational(1, 2));
Some(SVDDecomposition {
u: Matrix::identity(2),
sigma: Matrix::diagonal(vec![sigma1, sigma2]),
vt: Matrix::identity(2),
})
}
pub fn rank_via_svd(&self) -> usize {
match self {
Matrix::Identity(data) => data.size,
Matrix::Zero(_) => 0,
Matrix::Diagonal(data) => data
.diagonal_elements
.iter()
.filter(|elem| !elem.is_zero())
.count(),
Matrix::Scalar(data) => {
if data.scalar_value.is_zero() {
0
} else {
data.size
}
}
_ => {
if let Some(svd) = self.svd_decomposition() {
match svd.sigma {
Matrix::Diagonal(diag_data) => diag_data
.diagonal_elements
.iter()
.filter(|elem| !elem.is_zero())
.count(),
_ => 0,
}
} else {
0
}
}
}
}
pub fn condition_number_via_svd(&self) -> Expression {
if let Some(svd) = self.svd_decomposition() {
match svd.sigma {
Matrix::Diagonal(diag_data) => {
if diag_data.diagonal_elements.is_empty() {
return Expression::integer(1);
}
let mut max_val = &diag_data.diagonal_elements[0];
let mut min_val = &diag_data.diagonal_elements[0];
for val in &diag_data.diagonal_elements {
max_val = val; min_val = val; }
Expression::mul(vec![
max_val.clone(),
Expression::pow(min_val.clone(), Expression::integer(-1)),
])
.simplify()
}
_ => Expression::integer(1),
}
} else {
Expression::integer(1)
}
}
}