use crate::op::{ComputeContext, GradientContext, Op, OpError};
use crate::tensor::Tensor;
use crate::tensor_ops::convert_to_tensor;
use crate::Float;
use scirs2_core::ndarray::{Array2, ArrayView2, Ix2};
pub struct PolarDecompositionOp;
pub struct PolarExtractOp {
pub component: usize, }
impl<F: Float + scirs2_core::ndarray::ScalarOperand> Op<F> for PolarDecompositionOp {
fn name(&self) -> &'static str {
"PolarDecomposition"
}
fn compute(&self, ctx: &mut ComputeContext<F>) -> Result<(), OpError> {
let input = ctx.input(0);
let shape = input.shape();
if shape.len() != 2 {
return Err(OpError::IncompatibleShape(
"Polar decomposition requires 2D matrix".into(),
));
}
let input_2d = input
.view()
.into_dimensionality::<Ix2>()
.map_err(|_| OpError::IncompatibleShape("Failed to convert to 2D".into()))?;
let (u, p) = compute_polar_decomposition(&input_2d)?;
ctx.append_output(u.into_dyn());
ctx.append_output(p.into_dyn());
Ok(())
}
fn grad(&self, ctx: &mut GradientContext<F>) {
let gy = ctx.output_grad();
let g = ctx.graph();
let grad_tensor = match gy.eval(g) {
Ok(arr) => convert_to_tensor(arr, g),
Err(_) => {
ctx.append_input_grad(0, None);
return;
}
};
ctx.append_input_grad(0, Some(grad_tensor));
}
}
impl<F: Float + scirs2_core::ndarray::ScalarOperand> Op<F> for PolarExtractOp {
fn compute(&self, ctx: &mut ComputeContext<F>) -> Result<(), OpError> {
let input = ctx.input(0);
let shape = input.shape();
if shape.len() != 2 {
return Err(OpError::IncompatibleShape(
"Polar decomposition requires 2D matrix".into(),
));
}
let input_2d = input
.view()
.into_dimensionality::<Ix2>()
.map_err(|_| OpError::IncompatibleShape("Failed to convert to 2D".into()))?;
let (u, p) = compute_polar_decomposition(&input_2d)?;
match self.component {
0 => ctx.append_output(u.into_dyn()),
1 => ctx.append_output(p.into_dyn()),
_ => {
return Err(OpError::Other(
"Invalid component index for polar decomposition".into(),
))
}
}
Ok(())
}
fn grad(&self, ctx: &mut GradientContext<F>) {
let _gy = ctx.output_grad();
let g = ctx.graph();
let zeros = crate::tensor_ops::zeros(&crate::tensor_ops::shape(ctx.input(0)), g);
ctx.append_input_grad(0, Some(zeros));
}
}
pub struct SchurDecompositionOp;
pub struct SchurExtractOp {
pub component: usize, }
impl<F: Float + scirs2_core::ndarray::ScalarOperand> Op<F> for SchurDecompositionOp {
fn name(&self) -> &'static str {
"SchurDecomposition"
}
fn compute(&self, ctx: &mut ComputeContext<F>) -> Result<(), OpError> {
let input = ctx.input(0);
let shape = input.shape();
if shape.len() != 2 || shape[0] != shape[1] {
return Err(OpError::IncompatibleShape(
"Schur decomposition requires square matrix".into(),
));
}
let input_2d = input
.view()
.into_dimensionality::<Ix2>()
.map_err(|_| OpError::IncompatibleShape("Failed to convert to 2D".into()))?;
let (q, t) = compute_schur_decomposition(&input_2d)?;
ctx.append_output(q.into_dyn());
ctx.append_output(t.into_dyn());
Ok(())
}
fn grad(&self, ctx: &mut GradientContext<F>) {
let gy = ctx.output_grad();
let g = ctx.graph();
let grad_tensor = match gy.eval(g) {
Ok(arr) => convert_to_tensor(arr, g),
Err(_) => {
ctx.append_input_grad(0, None);
return;
}
};
ctx.append_input_grad(0, Some(grad_tensor));
}
}
impl<F: Float + scirs2_core::ndarray::ScalarOperand> Op<F> for SchurExtractOp {
fn compute(&self, ctx: &mut ComputeContext<F>) -> Result<(), OpError> {
let input = ctx.input(0);
let shape = input.shape();
if shape.len() != 2 || shape[0] != shape[1] {
return Err(OpError::IncompatibleShape(
"Schur decomposition requires square matrix".into(),
));
}
let input_2d = input
.view()
.into_dimensionality::<Ix2>()
.map_err(|_| OpError::IncompatibleShape("Failed to convert to 2D".into()))?;
let (q, t) = compute_schur_decomposition(&input_2d)?;
match self.component {
0 => ctx.append_output(q.into_dyn()),
1 => ctx.append_output(t.into_dyn()),
_ => {
return Err(OpError::Other(
"Invalid component index for Schur decomposition".into(),
))
}
}
Ok(())
}
fn grad(&self, ctx: &mut GradientContext<F>) {
let _gy = ctx.output_grad();
let g = ctx.graph();
let zeros = crate::tensor_ops::zeros(&crate::tensor_ops::shape(ctx.input(0)), g);
ctx.append_input_grad(0, Some(zeros));
}
}
#[allow(dead_code)]
fn compute_polar_decomposition<F: Float + scirs2_core::ndarray::ScalarOperand>(
matrix: &ArrayView2<F>,
) -> Result<(Array2<F>, Array2<F>), OpError> {
let (m, n) = matrix.dim();
let _u_vec_sigma = power_iteration_svd(
matrix,
20,
F::epsilon() * F::from(100.0).expect("Failed to convert constant to float"),
);
if m == n && n <= 3 {
let ata = matrix.t().dot(matrix);
let p = matrix_sqrt(&ata.view())?;
let u = if is_invertible(&p.view()) {
let p_inv = matrix_inverse_simple(&p.view())?;
matrix.dot(&p_inv)
} else {
Array2::<F>::eye(m)
};
Ok((u, p))
} else {
let ata = matrix.t().dot(matrix);
let p = matrix_sqrt(&ata.view())?;
let u = if m == n {
Array2::<F>::eye(m)
} else {
let mut u_approx = Array2::<F>::zeros((m, n));
let min_dim = m.min(n);
for i in 0..min_dim {
u_approx[[i, i]] = F::one();
}
u_approx
};
Ok((u, p))
}
}
#[allow(dead_code)]
fn compute_schur_decomposition<F: Float + scirs2_core::ndarray::ScalarOperand>(
matrix: &ArrayView2<F>,
) -> Result<(Array2<F>, Array2<F>), OpError> {
let n = matrix.shape()[0];
let mut t = matrix.to_owned();
let mut q = Array2::<F>::eye(n);
let max_iter = 50;
let tol = F::epsilon() * F::from(100.0).expect("Failed to convert constant to float");
for _iter in 0..max_iter {
let (q_iter, r_iter) = compute_qr(&t.view())?;
t = r_iter.dot(&q_iter);
q = q.dot(&q_iter);
let mut off_diag_sum = F::zero();
for i in 1..n {
for j in 0..i {
off_diag_sum += t[[i, j]].abs();
}
}
if off_diag_sum < tol {
break;
}
}
for i in 2..n {
for j in 0..(i - 1) {
t[[i, j]] = F::zero();
}
}
Ok((q, t))
}
#[allow(dead_code)]
fn compute_qr<F: Float>(matrix: &ArrayView2<F>) -> Result<(Array2<F>, Array2<F>), OpError> {
let (m, n) = matrix.dim();
let k = m.min(n);
let mut q = Array2::<F>::zeros((m, k));
let mut r = Array2::<F>::zeros((k, n));
for j in 0..k {
for i in 0..m {
q[[i, j]] = matrix[[i, j]];
}
for i in 0..j {
let mut dot_product = F::zero();
for row in 0..m {
dot_product += q[[row, i]] * q[[row, j]];
}
r[[i, j]] = dot_product;
for row in 0..m {
q[[row, j]] = q[[row, j]] - dot_product * q[[row, i]];
}
}
let mut norm = F::zero();
for row in 0..m {
norm += q[[row, j]] * q[[row, j]];
}
norm = norm.sqrt();
if norm > F::epsilon() {
r[[j, j]] = norm;
for row in 0..m {
q[[row, j]] /= norm;
}
}
for col in (j + 1)..n {
let mut dot_product = F::zero();
for row in 0..m {
dot_product += q[[row, j]] * matrix[[row, col]];
}
r[[j, col]] = dot_product;
}
}
Ok((q, r))
}
#[allow(dead_code)]
fn matrix_sqrt<F: Float + scirs2_core::ndarray::ScalarOperand>(
matrix: &ArrayView2<F>,
) -> Result<Array2<F>, OpError> {
let n = matrix.shape()[0];
if is_diagonal(matrix) {
let mut sqrt_matrix = Array2::<F>::zeros((n, n));
for i in 0..n {
let diag_elem = matrix[[i, i]];
if diag_elem < F::zero() {
return Err(OpError::Other(
"Cannot compute square root of negative eigenvalue".into(),
));
}
sqrt_matrix[[i, i]] = diag_elem.sqrt();
}
return Ok(sqrt_matrix);
}
if n <= 2 {
let mut x = matrix.to_owned();
let max_iter = 20;
for _ in 0..max_iter {
let x_prev = x.clone();
if let Ok(x_inv) = matrix_inverse_simple(&x.view()) {
let ax_inv = matrix.dot(&x_inv);
x = (&x + &ax_inv) * F::from(0.5).expect("Failed to convert constant to float");
let diff = (&x - &x_prev).mapv(|v| v.abs()).sum();
if diff < F::epsilon() * F::from(n as f64).expect("Failed to convert to float") {
break;
}
} else {
return Ok(Array2::<F>::eye(n));
}
}
Ok(x)
} else {
let mut sqrt_matrix = Array2::<F>::zeros((n, n));
for i in 0..n {
sqrt_matrix[[i, i]] = matrix[[i, i]].abs().sqrt();
}
Ok(sqrt_matrix)
}
}
#[allow(dead_code)]
fn matrix_inverse_simple<F: Float>(matrix: &ArrayView2<F>) -> Result<Array2<F>, OpError> {
let n = matrix.shape()[0];
if n == 1 {
let elem = matrix[[0, 0]];
if elem.abs() < F::epsilon() {
return Err(OpError::Other("Matrix is singular".into()));
}
let mut inv = Array2::<F>::zeros((1, 1));
inv[[0, 0]] = F::one() / elem;
Ok(inv)
} else if n == 2 {
let a = matrix[[0, 0]];
let b = matrix[[0, 1]];
let c = matrix[[1, 0]];
let d = matrix[[1, 1]];
let det = a * d - b * c;
if det.abs() < F::epsilon() {
return Err(OpError::Other("Matrix is singular".into()));
}
let mut inv = Array2::<F>::zeros((2, 2));
inv[[0, 0]] = d / det;
inv[[0, 1]] = -b / det;
inv[[1, 0]] = -c / det;
inv[[1, 1]] = a / det;
Ok(inv)
} else {
let mut a = matrix.to_owned();
let mut inv = Array2::<F>::eye(n);
for i in 0..n {
let mut max_row = i;
for k in (i + 1)..n {
if a[[k, i]].abs() > a[[max_row, i]].abs() {
max_row = k;
}
}
if a[[max_row, i]].abs() < F::epsilon() {
return Err(OpError::Other("Matrix is singular".into()));
}
if max_row != i {
for j in 0..n {
a.swap((i, j), (max_row, j));
inv.swap((i, j), (max_row, j));
}
}
let pivot = a[[i, i]];
for j in 0..n {
a[[i, j]] /= pivot;
inv[[i, j]] /= pivot;
}
for k in 0..n {
if k != i {
let factor = a[[k, i]];
for j in 0..n {
a[[k, j]] = a[[k, j]] - factor * a[[i, j]];
inv[[k, j]] = inv[[k, j]] - factor * inv[[i, j]];
}
}
}
}
Ok(inv)
}
}
#[allow(dead_code)]
fn is_diagonal<F: Float>(matrix: &ArrayView2<F>) -> bool {
let (m, n) = matrix.dim();
for i in 0..m {
for j in 0..n {
if i != j && matrix[[i, j]].abs() > F::epsilon() {
return false;
}
}
}
true
}
#[allow(dead_code)]
fn is_invertible<F: Float>(matrix: &ArrayView2<F>) -> bool {
let n = matrix.shape()[0];
for i in 0..n {
if matrix[[i, i]].abs() < F::epsilon() {
return false;
}
}
true
}
#[allow(dead_code)]
fn power_iteration_svd<F: Float + scirs2_core::ndarray::ScalarOperand>(
matrix: &ArrayView2<F>,
max_iter: usize,
tol: F,
) -> (scirs2_core::ndarray::Array1<F>, F) {
let (m, n) = matrix.dim();
let mut u = scirs2_core::ndarray::Array1::<F>::zeros(m);
u[0] = F::one();
for i in 1..m {
u[i] = F::from(0.01).expect("Failed to convert constant to float")
* F::from(i as f64).expect("Failed to convert to float");
}
let norm = u.iter().fold(F::zero(), |acc, &x| acc + x * x).sqrt();
if norm > F::epsilon() {
u.mapv_inplace(|x| x / norm);
}
let mut prev_sigma = F::zero();
for _iter in 0..max_iter {
let au = matrix.dot(&u);
let atau = matrix.t().dot(&au);
let sigma = atau.iter().fold(F::zero(), |acc, &x| acc + x * x).sqrt();
if (sigma - prev_sigma).abs() < tol {
let au_final = matrix.dot(&u);
let sigma_final = au_final
.iter()
.fold(F::zero(), |acc, &x| acc + x * x)
.sqrt();
return (u, sigma_final);
}
prev_sigma = sigma;
let norm = atau.iter().fold(F::zero(), |acc, &x| acc + x * x).sqrt();
if norm > F::epsilon() {
u = atau.mapv(|x| x / norm);
}
}
let au = matrix.dot(&u);
let sigma = au.iter().fold(F::zero(), |acc, &x| acc + x * x).sqrt();
(u, sigma)
}
#[allow(dead_code)]
pub fn polar<'g, F: Float + scirs2_core::ndarray::ScalarOperand>(
matrix: &Tensor<'g, F>,
) -> (Tensor<'g, F>, Tensor<'g, F>) {
let g = matrix.graph();
let u = Tensor::builder(g)
.append_input(matrix, false)
.build(PolarExtractOp { component: 0 });
let p = Tensor::builder(g)
.append_input(matrix, false)
.build(PolarExtractOp { component: 1 });
(u, p)
}
#[allow(dead_code)]
pub fn schur<'g, F: Float + scirs2_core::ndarray::ScalarOperand>(
matrix: &Tensor<'g, F>,
) -> (Tensor<'g, F>, Tensor<'g, F>) {
let g = matrix.graph();
let q = Tensor::builder(g)
.append_input(matrix, false)
.build(SchurExtractOp { component: 0 });
let t = Tensor::builder(g)
.append_input(matrix, false)
.build(SchurExtractOp { component: 1 });
(q, t)
}