use super::super::jacobi::{
self, JacobiRotation, LinalgElement, apply_rotation_to_columns, argsort_desc,
compute_gram_elements, identity_matrix, normalize_columns, permute_columns,
};
use super::super::{CpuClient, CpuRuntime};
use crate::algorithm::linalg::{
SvdDecomposition, linalg_demote, linalg_promote, validate_linalg_dtype, validate_matrix_2d,
};
use crate::dtype::{DType, Element};
use crate::error::Result;
use crate::runtime::RuntimeClient;
use crate::tensor::Tensor;
/// Computes the thin singular value decomposition `A = U · diag(S) · Vᵀ` of a
/// 2-D tensor on the CPU backend.
///
/// The input is first promoted with `linalg_promote`. The decomposition runs in the
/// promoted precision. Each factor is then converted back to the input's original
/// dtype with `linalg_demote`.
///
/// # Errors
///
/// Propagates any error from dtype validation, promotion, 2-D shape validation,
/// the typed decomposition, or demotion of the factors.
pub fn svd_decompose_impl(
    client: &CpuClient,
    a: &Tensor<CpuRuntime>,
) -> Result<SvdDecomposition<CpuRuntime>> {
    validate_linalg_dtype(a.dtype())?;
    let (promoted, original_dtype) = linalg_promote(client, a)?;
    let (m, n) = validate_matrix_2d(promoted.shape())?;

    let SvdDecomposition { u, s, vt } = match promoted.dtype() {
        DType::F32 => svd_decompose_typed::<f32>(client, &promoted, m, n)?,
        DType::F64 => svd_decompose_typed::<f64>(client, &promoted, m, n)?,
        // Presumably `linalg_promote` only yields F32/F64; any other dtype here is
        // treated as an internal invariant violation.
        _ => unreachable!(),
    };

    let u = linalg_demote(client, u, original_dtype)?;
    let s = linalg_demote(client, s, original_dtype)?;
    let vt = linalg_demote(client, vt, original_dtype)?;
    Ok(SvdDecomposition { u, s, vt })
}
fn svd_decompose_typed<T: Element + LinalgElement>(
client: &CpuClient,
a: &Tensor<CpuRuntime>,
m: usize,
n: usize,
) -> Result<SvdDecomposition<CpuRuntime>> {
let device = client.device();
let k = m.min(n);
if m == 0 || n == 0 {
let u = Tensor::<CpuRuntime>::from_slice::<T>(&[], &[m, k], device);
let s = Tensor::<CpuRuntime>::from_slice::<T>(&[], &[k], device);
let vt = Tensor::<CpuRuntime>::from_slice::<T>(&[], &[k, n], device);
return Ok(SvdDecomposition { u, s, vt });
}
let transpose = m < n;
let (work_m, work_n) = if transpose { (n, m) } else { (m, n) };
let mut b: Vec<T> = if transpose {
let a_data: Vec<T> = a.to_vec();
let mut b_transposed = vec![T::zero(); work_m * work_n];
for i in 0..m {
for j in 0..n {
b_transposed[j * work_n + i] = a_data[i * n + j];
}
}
b_transposed
} else {
a.to_vec()
};
let work_k = work_m.min(work_n);
let mut v: Vec<T> = identity_matrix(work_n);
let eps = T::epsilon_val();
let tol = (work_n as f64) * eps;
let max_sweeps = 30;
for _sweep in 0..max_sweeps {
let mut off_diag_sum = 0.0f64;
for p in 0..work_n {
for q in (p + 1)..work_n {
let (a_pp, a_qq, a_pq) = compute_gram_elements(&b, work_m, work_n, p, q);
off_diag_sum += a_pq.to_f64() * a_pq.to_f64();
if a_pq.abs_val().to_f64() < tol * (a_pp.to_f64() * a_qq.to_f64()).sqrt() {
continue;
}
let rot = JacobiRotation::compute(a_pp.to_f64(), a_qq.to_f64(), a_pq.to_f64());
apply_rotation_to_columns(&mut b, work_m, work_n, p, q, &rot);
apply_rotation_to_columns(&mut v, work_n, work_n, p, q, &rot);
}
}
if off_diag_sum.sqrt() < tol {
break;
}
}
let singular_values = normalize_columns(&mut b, work_m, work_n, eps);
let u_data = b;
let indices = argsort_desc(&singular_values);
let s_sorted: Vec<T> = jacobi::permute_vector(&singular_values, &indices)
.into_iter()
.take(work_k)
.collect();
let u_sorted = permute_columns(&u_data, work_m, work_n, &indices, work_k);
let mut vt_sorted: Vec<T> = vec![T::zero(); work_k * work_n];
for (new_idx, &old_idx) in indices.iter().take(work_k).enumerate() {
for j in 0..work_n {
vt_sorted[new_idx * work_n + j] = v[j * work_n + old_idx];
}
}
if transpose {
let mut u_final: Vec<T> = vec![T::zero(); m * k];
for i in 0..k {
for j in 0..m {
u_final[j * k + i] = vt_sorted[i * work_n + j];
}
}
let mut vt_final: Vec<T> = vec![T::zero(); k * n];
for i in 0..work_m {
for j in 0..work_k {
vt_final[j * n + i] = u_sorted[i * work_k + j];
}
}
let u_tensor = Tensor::<CpuRuntime>::from_slice(&u_final, &[m, k], device);
let s_tensor = Tensor::<CpuRuntime>::from_slice(&s_sorted, &[k], device);
let vt_tensor = Tensor::<CpuRuntime>::from_slice(&vt_final, &[k, n], device);
Ok(SvdDecomposition {
u: u_tensor,
s: s_tensor,
vt: vt_tensor,
})
} else {
let u_tensor = Tensor::<CpuRuntime>::from_slice(&u_sorted, &[m, k], device);
let s_tensor = Tensor::<CpuRuntime>::from_slice(&s_sorted, &[k], device);
let vt_tensor = Tensor::<CpuRuntime>::from_slice(&vt_sorted, &[k, n], device);
Ok(SvdDecomposition {
u: u_tensor,
s: s_tensor,
vt: vt_tensor,
})
}
}