use super::FixedPoint;
use super::FixedVector;
use super::FixedMatrix;
use super::linalg::{
compute_tier_dot_raw, compute_tier_sub_dot_raw, compute_tier_sub_dot_compute,
upscale_to_compute, round_to_storage, givens, convergence_threshold,
convergence_threshold_tight, apply_givens_compute,
};
use crate::fixed_point::universal::fasc::stack_evaluator::compute::{
sqrt_at_compute_tier, compute_divide, downscale_to_storage,
compute_multiply, compute_add, compute_negate,
compute_is_negative, compute_is_zero,
};
use crate::fixed_point::universal::fasc::stack_evaluator::BinaryStorage;
use crate::fixed_point::core_types::errors::OverflowDetected;
/// Row-pivoted LU factorization `P·A = L·U`, as produced by [`lu_decompose`].
#[derive(Debug, Clone)]
pub struct LUDecomposition {
    /// Unit lower-triangular factor (ones on the diagonal).
    pub l: FixedMatrix,
    /// Upper-triangular factor.
    pub u: FixedMatrix,
    /// Row permutation: row `i` of `P·A` is row `perm[i]` of the original matrix.
    pub perm: Vec<usize>,
    /// Number of row interchanges performed; its parity is the sign of `det(P)`.
    pub num_swaps: usize,
}
pub fn lu_decompose(a: &FixedMatrix) -> Result<LUDecomposition, OverflowDetected> {
assert!(a.is_square(), "lu_decompose: matrix must be square");
let n = a.rows();
let mut pa = a.clone();
let mut l = FixedMatrix::new(n, n);
let mut u = FixedMatrix::new(n, n);
let mut perm: Vec<usize> = (0..n).collect();
let mut num_swaps: usize = 0;
for k in 0..n {
let mut max_abs = FixedPoint::ZERO;
let mut max_row = k;
for i in k..n {
let candidate = if k == 0 {
pa.get(i, k)
} else {
let l_row = l.row_raw_range(i, 0, k);
let u_col = u.col_raw_range(k, 0, k);
FixedPoint::from_raw(compute_tier_sub_dot_raw(pa.get(i, k).raw(), &l_row, &u_col))
};
if candidate.abs() > max_abs {
max_abs = candidate.abs();
max_row = i;
}
}
if max_abs.is_zero() {
return Err(OverflowDetected::DivisionByZero);
}
if max_row != k {
pa.swap_rows(k, max_row);
perm.swap(k, max_row);
num_swaps += 1;
for j in 0..k {
let tmp = l.get(k, j);
l.set(k, j, l.get(max_row, j));
l.set(max_row, j, tmp);
}
}
for j in k..n {
if k == 0 {
u.set(k, j, pa.get(k, j));
} else {
let l_row = l.row_raw_range(k, 0, k);
let u_col = u.col_raw_range(j, 0, k);
u.set(k, j, FixedPoint::from_raw(
compute_tier_sub_dot_raw(pa.get(k, j).raw(), &l_row, &u_col)
));
}
}
let pivot = u.get(k, k);
l.set(k, k, FixedPoint::one()); for i in (k + 1)..n {
let numerator = if k == 0 {
pa.get(i, k)
} else {
let l_row = l.row_raw_range(i, 0, k);
let u_col = u.col_raw_range(k, 0, k);
FixedPoint::from_raw(compute_tier_sub_dot_raw(pa.get(i, k).raw(), &l_row, &u_col))
};
l.set(i, k, numerator / pivot);
}
}
Ok(LUDecomposition { l, u, perm, num_swaps })
}
impl LUDecomposition {
    /// Solves `A·x = b` with the stored factorization `P·A = L·U`.
    ///
    /// `b` is permuted by `perm`, forward-substituted through the unit
    /// lower-triangular `L`, then back-substituted through `U`. Each inner
    /// product is accumulated at the compute tier.
    ///
    /// # Panics
    /// Panics if `b.len()` differs from the factorized dimension.
    ///
    /// # Errors
    /// Returns [`OverflowDetected::DivisionByZero`] if a diagonal entry of `U` is zero.
    pub fn solve(&self, b: &FixedVector) -> Result<FixedVector, OverflowDetected> {
        let n = self.l.rows();
        assert_eq!(b.len(), n, "LU solve: dimension mismatch");
        // pb = P·b
        let mut pb = FixedVector::new(n);
        for i in 0..n {
            pb[i] = b[self.perm[i]];
        }
        // Forward substitution L·y = P·b; L has a unit diagonal, so no division.
        let mut y = FixedVector::new(n);
        for i in 0..n {
            if i == 0 {
                y[0] = pb[0];
            } else {
                let l_row = self.l.row_raw_range(i, 0, i);
                let y_raw: Vec<BinaryStorage> = (0..i).map(|j| y[j].raw()).collect();
                y[i] = FixedPoint::from_raw(
                    compute_tier_sub_dot_raw(pb[i].raw(), &l_row, &y_raw)
                );
            }
        }
        // Back substitution U·x = y.
        let mut x = FixedVector::new(n);
        for i in (0..n).rev() {
            let diag = self.u.get(i, i);
            if diag.is_zero() {
                return Err(OverflowDetected::DivisionByZero);
            }
            if i == n - 1 {
                x[n - 1] = y[n - 1] / diag;
            } else {
                let u_row = self.u.row_raw_range(i, i + 1, n);
                let x_raw: Vec<BinaryStorage> = (i + 1..n).map(|j| x[j].raw()).collect();
                let numerator = FixedPoint::from_raw(
                    compute_tier_sub_dot_raw(y[i].raw(), &u_row, &x_raw)
                );
                x[i] = numerator / diag;
            }
        }
        Ok(x)
    }

    /// Returns `det(A) = (-1)^num_swaps · Π U[i][i]`.
    ///
    /// The product is accumulated at the compute tier and rounded to storage
    /// once. A 0×0 factorization yields 1 (the empty product) instead of
    /// panicking on a missing `U[0][0]`.
    pub fn determinant(&self) -> FixedPoint {
        let n = self.u.rows();
        if n == 0 {
            return FixedPoint::one();
        }
        use crate::fixed_point::universal::fasc::stack_evaluator::compute::compute_multiply;
        let mut acc = upscale_to_compute(self.u.get(0, 0).raw());
        for i in 1..n {
            acc = compute_multiply(acc, upscale_to_compute(self.u.get(i, i).raw()));
        }
        let det = FixedPoint::from_raw(round_to_storage(acc));
        if self.num_swaps % 2 == 1 { -det } else { det }
    }

    /// One step of iterative refinement: returns `x + dx` where `A·dx = b - A·x`.
    ///
    /// The residual is formed at the compute tier; `a` should be the matrix
    /// that was factorized.
    ///
    /// # Errors
    /// Propagates errors from [`LUDecomposition::solve`].
    pub fn refine(&self, a: &FixedMatrix, b: &FixedVector, x: &FixedVector) -> Result<FixedVector, OverflowDetected> {
        let n = a.rows();
        // x is constant while the residual is formed, so convert it once.
        let x_raw: Vec<BinaryStorage> = (0..n).map(|j| x[j].raw()).collect();
        let mut r = FixedVector::new(n);
        for i in 0..n {
            let a_row = a.row_raw_range(i, 0, n);
            r[i] = FixedPoint::from_raw(
                compute_tier_sub_dot_raw(b[i].raw(), &a_row, &x_raw)
            );
        }
        let dx = self.solve(&r)?;
        let mut x_refined = FixedVector::new(n);
        for i in 0..n {
            x_refined[i] = x[i] + dx[i];
        }
        Ok(x_refined)
    }

    /// Returns `A⁻¹`, built column by column by solving `A·x = e_j`.
    ///
    /// # Errors
    /// Propagates errors from [`LUDecomposition::solve`].
    pub fn inverse(&self) -> Result<FixedMatrix, OverflowDetected> {
        let n = self.l.rows();
        let mut inv = FixedMatrix::new(n, n);
        for j in 0..n {
            let mut e_j = FixedVector::new(n);
            e_j[j] = FixedPoint::one();
            let col = self.solve(&e_j)?;
            for i in 0..n {
                inv.set(i, j, col[i]);
            }
        }
        Ok(inv)
    }
}
/// Householder QR factorization `A = Q·R`, as produced by [`qr_decompose`].
#[derive(Debug, Clone)]
pub struct QRDecomposition {
    /// `m × m` product of the Householder reflectors (starts from the identity).
    pub q: FixedMatrix,
    /// `m × n` upper-triangular factor.
    pub r: FixedMatrix,
}
/// Computes a Householder QR factorization `A = Q·R` of an `m × n` matrix with `m >= n`.
///
/// For each column `k`, a reflector `H = I - 2·v·vᵀ / (vᵀ·v)` annihilates the
/// entries below the diagonal. `R ← H·R` and `Q ← Q·H`. Columns whose trailing
/// part (or reflector) is exactly zero are skipped.
///
/// # Panics
/// Panics if `m < n`.
///
/// # Errors
/// Propagates any error from the fixed-point square root of a column norm.
pub fn qr_decompose(a: &FixedMatrix) -> Result<QRDecomposition, OverflowDetected> {
    let (m, n) = (a.rows(), a.cols());
    assert!(m >= n, "qr_decompose: requires m >= n");
    let mut r = a.clone();
    let mut q = FixedMatrix::identity(m);
    let two = FixedPoint::from_int(2);
    for k in 0..n {
        // Trailing part of column k (rows k..m).
        let column: Vec<BinaryStorage> = (k..m).map(|row| r.get(row, k).raw()).collect();
        let norm_sq = FixedPoint::from_raw(compute_tier_dot_raw(&column, &column));
        if norm_sq.is_zero() {
            continue;
        }
        let norm = norm_sq.try_sqrt()?;
        let head = r.get(k, k);
        // alpha takes the sign opposite to the head entry to avoid cancellation in head - alpha.
        let alpha = if head.is_negative() { norm } else { -norm };
        let reflector: Vec<FixedPoint> = std::iter::once(head - alpha)
            .chain(column[1..].iter().map(|&raw| FixedPoint::from_raw(raw)))
            .collect();
        let reflector_raw: Vec<BinaryStorage> = reflector.iter().map(|fp| fp.raw()).collect();
        let vtv = FixedPoint::from_raw(compute_tier_dot_raw(&reflector_raw, &reflector_raw));
        if vtv.is_zero() {
            continue;
        }

        // R <- H·R, restricted to rows k..m and columns k..n.
        for col in k..n {
            let col_raw: Vec<BinaryStorage> = (k..m).map(|row| r.get(row, col).raw()).collect();
            let proj = FixedPoint::from_raw(compute_tier_dot_raw(&reflector_raw, &col_raw));
            let scale = two * proj / vtv;
            for (offset, &vi) in reflector.iter().enumerate() {
                let row = k + offset;
                r.set(row, col, r.get(row, col) - scale * vi);
            }
        }

        // Q <- Q·H, touching columns k..m of every row.
        for row in 0..m {
            let q_row: Vec<BinaryStorage> = (k..m).map(|col| q.get(row, col).raw()).collect();
            let proj = FixedPoint::from_raw(compute_tier_dot_raw(&q_row, &reflector_raw));
            let scale = two * proj / vtv;
            for (offset, &vj) in reflector.iter().enumerate() {
                let col = k + offset;
                q.set(row, col, q.get(row, col) - scale * vj);
            }
        }
    }
    Ok(QRDecomposition { q, r })
}
impl QRDecomposition {
    /// Solves `A·x = b` from `A = Q·R` (the least-squares solution when `m > n`).
    ///
    /// Forms the first `n` entries of `Qᵀ·b` at the compute tier, then
    /// back-substitutes through the leading `n × n` upper-triangular block of `R`.
    ///
    /// # Panics
    /// Panics if `b.len()` differs from the number of rows of `Q`.
    ///
    /// # Errors
    /// Returns [`OverflowDetected::DivisionByZero`] if a diagonal entry of `R` is zero.
    pub fn solve(&self, b: &FixedVector) -> Result<FixedVector, OverflowDetected> {
        let m = self.q.rows();
        let n = self.r.cols();
        assert_eq!(b.len(), m, "QR solve: dimension mismatch");
        // b is loop-invariant: convert it to raw storage once, not once per column of Q.
        let b_raw: Vec<BinaryStorage> = (0..m).map(|j| b[j].raw()).collect();
        // Back substitution reads only the first n entries of Qᵀ·b.
        let mut qtb = FixedVector::new(n);
        for i in 0..n {
            let q_col_raw: Vec<BinaryStorage> = (0..m).map(|j| self.q.get(j, i).raw()).collect();
            qtb[i] = FixedPoint::from_raw(compute_tier_dot_raw(&q_col_raw, &b_raw));
        }
        let mut x = FixedVector::new(n);
        for i in (0..n).rev() {
            let diag = self.r.get(i, i);
            if diag.is_zero() {
                return Err(OverflowDetected::DivisionByZero);
            }
            if i == n - 1 {
                x[n - 1] = qtb[n - 1] / diag;
            } else {
                let r_row = self.r.row_raw_range(i, i + 1, n);
                let x_raw: Vec<BinaryStorage> = (i + 1..n).map(|j| x[j].raw()).collect();
                let numerator = FixedPoint::from_raw(
                    compute_tier_sub_dot_raw(qtb[i].raw(), &r_row, &x_raw)
                );
                x[i] = numerator / diag;
            }
        }
        Ok(x)
    }
}
/// Cholesky factorization `A = L·Lᵀ`, as produced by [`cholesky_decompose`].
#[derive(Debug, Clone)]
pub struct CholeskyDecomposition {
    /// Lower-triangular factor; the diagonal comes from square roots of positive radicands.
    pub l: FixedMatrix,
}
/// Computes the Cholesky factorization `A = L·Lᵀ` of a symmetric positive-definite matrix.
///
/// Only the diagonal and lower triangle of `a` are read. Each diagonal radicand
/// and each off-diagonal quotient is formed at the compute tier and rounded once
/// to storage. Off-diagonal entries are divided by the *rounded* diagonal, so
/// `L` is self-consistent at storage precision.
///
/// # Panics
/// Panics if `a` is not square.
///
/// # Errors
/// * [`OverflowDetected::DomainError`] if a radicand is zero or negative.
/// * [`OverflowDetected::TierOverflow`] if a result does not fit in storage.
/// * [`OverflowDetected::DivisionByZero`] if a compute-tier division fails.
pub fn cholesky_decompose(a: &FixedMatrix) -> Result<CholeskyDecomposition, OverflowDetected> {
    assert!(a.is_square(), "cholesky_decompose: matrix must be square");
    let n = a.rows();
    let mut l = FixedMatrix::new(n, n);
    for col in 0..n {
        // radicand = a[col][col] - Σ_k L[col][k]²
        let radicand = match col {
            0 => upscale_to_compute(a.get(0, 0).raw()),
            _ => {
                let prefix = l.row_raw_range(col, 0, col);
                compute_tier_sub_dot_compute(a.get(col, col).raw(), &prefix, &prefix)
            }
        };
        if compute_is_negative(&radicand) || compute_is_zero(&radicand) {
            return Err(OverflowDetected::DomainError);
        }
        let diag_raw = downscale_to_storage(sqrt_at_compute_tier(radicand))
            .map_err(|_| OverflowDetected::TierOverflow)?;
        let diag = FixedPoint::from_raw(diag_raw);
        l.set(col, col, diag);
        let divisor = upscale_to_compute(diag.raw());
        for row in (col + 1)..n {
            // numerator = a[row][col] - Σ_k L[row][k]·L[col][k]
            let numerator = match col {
                0 => upscale_to_compute(a.get(row, 0).raw()),
                _ => {
                    let row_prefix = l.row_raw_range(row, 0, col);
                    let col_prefix = l.row_raw_range(col, 0, col);
                    compute_tier_sub_dot_compute(a.get(row, col).raw(), &row_prefix, &col_prefix)
                }
            };
            let quotient = compute_divide(numerator, divisor)
                .map_err(|_| OverflowDetected::DivisionByZero)?;
            let entry_raw = downscale_to_storage(quotient)
                .map_err(|_| OverflowDetected::TierOverflow)?;
            l.set(row, col, FixedPoint::from_raw(entry_raw));
        }
    }
    Ok(CholeskyDecomposition { l })
}
impl CholeskyDecomposition {
pub fn solve(&self, b: &FixedVector) -> Result<FixedVector, OverflowDetected> {
let n = self.l.rows();
assert_eq!(b.len(), n, "Cholesky solve: dimension mismatch");
let mut y = FixedVector::new(n);
for i in 0..n {
let diag = self.l.get(i, i);
if i == 0 {
y[0] = b[0] / diag;
} else {
let l_row = self.l.row_raw_range(i, 0, i);
let y_raw: Vec<BinaryStorage> = (0..i).map(|j| y[j].raw()).collect();
let numerator = FixedPoint::from_raw(
compute_tier_sub_dot_raw(b[i].raw(), &l_row, &y_raw)
);
y[i] = numerator / diag;
}
}
let mut x = FixedVector::new(n);
for i in (0..n).rev() {
let diag = self.l.get(i, i);
if i == n - 1 {
x[n - 1] = y[n - 1] / diag;
} else {
let lt_row = self.l.col_raw_range(i, i + 1, n);
let x_raw: Vec<BinaryStorage> = (i + 1..n).map(|j| x[j].raw()).collect();
let numerator = FixedPoint::from_raw(
compute_tier_sub_dot_raw(y[i].raw(), <_row, &x_raw)
);
x[i] = numerator / diag;
}
}
Ok(x)
}
pub fn determinant(&self) -> FixedPoint {
let n = self.l.rows();
let mut det_l = FixedPoint::one();
for i in 0..n {
det_l = det_l * self.l.get(i, i);
}
det_l * det_l
}
}
/// Eigen-decomposition of a symmetric matrix, as produced by [`eigen_symmetric`].
#[derive(Debug, Clone)]
pub struct EigenDecomposition {
    /// Eigenvalues, sorted by decreasing absolute value.
    pub values: FixedVector,
    /// Eigenvectors stored as columns; column `k` belongs to `values[k]`.
    pub vectors: FixedMatrix,
}
/// Computes eigenvalues and eigenvectors of a symmetric matrix with cyclic Jacobi rotations.
///
/// The input is assumed symmetric: only entries `s[p][q]` with `p < q` are
/// consulted when choosing rotations, and each rotation writes `(r, p)` / `(p, r)`
/// pairs together. Sweeps stop when any of these holds:
/// * the off-diagonal mass `2·Σ_{i<j} s_ij²` is at most `threshold²`;
/// * five consecutive checks show no decrease;
/// * 100 sweeps have run.
///
/// A final clean-up rotation then targets the largest remaining off-diagonal
/// entry. `threshold` comes from `convergence_threshold_tight` applied to the
/// largest initial |diagonal| entry.
///
/// Eigenvalues are returned sorted by decreasing absolute value. Column `k` of
/// `vectors` is the eigenvector for `values[k]`.
///
/// # Panics
/// Panics if `a` is not square.
///
/// # Errors
/// The signature returns `Result`, but no code path here produces `Err`.
pub fn eigen_symmetric(a: &FixedMatrix) -> Result<EigenDecomposition, OverflowDetected> {
    assert!(a.is_square(), "eigen_symmetric: matrix must be square");
    let n = a.rows();
    if n == 0 {
        return Ok(EigenDecomposition {
            values: FixedVector::new(0),
            vectors: FixedMatrix::new(0, 0),
        });
    }
    if n == 1 {
        return Ok(EigenDecomposition {
            values: FixedVector::from_slice(&[a.get(0, 0)]),
            vectors: FixedMatrix::identity(1),
        });
    }
    let mut s = a.clone();
    let mut v = FixedMatrix::identity(n);
    let two = FixedPoint::from_int(2);
    // The tolerance is scaled by the largest initial |diagonal| entry.
    let diag_max = {
        let mut m = FixedPoint::ZERO;
        for i in 0..n {
            let d = s.get(i, i).abs();
            if d > m {
                m = d;
            }
        }
        m
    };
    let threshold = convergence_threshold_tight(diag_max);
    // 2·Σ_{i<j} mat[i][j]²: squared Frobenius norm of the off-diagonal part.
    let off_diag_norm_sq = |mat: &FixedMatrix| -> FixedPoint {
        let mut sum = FixedPoint::ZERO;
        for i in 0..n {
            for j in (i + 1)..n {
                let x = mat.get(i, j);
                sum += x * x;
            }
        }
        two * sum
    };
    let max_sweeps = 100;
    let mut prev_off = off_diag_norm_sq(&s);
    let mut stagnation_count = 0usize;
    for _sweep in 0..max_sweeps {
        let off = off_diag_norm_sq(&s);
        if off <= threshold * threshold {
            break;
        }
        if off >= prev_off {
            stagnation_count += 1;
            if stagnation_count >= 5 {
                break;
            }
        } else {
            stagnation_count = 0;
        }
        prev_off = off;
        for p in 0..n {
            for q in (p + 1)..n {
                // Entries already below tolerance are left alone.
                if s.get(p, q).abs() <= threshold {
                    continue;
                }
                jacobi_rotate(&mut s, &mut v, p, q, threshold);
            }
        }
    }

    // Clean-up: one more rotation on the largest remaining off-diagonal entry.
    // A zero tolerance means the 45° rotation is used only for exactly equal diagonals.
    let mut max_abs = FixedPoint::ZERO;
    let (mut max_p, mut max_q) = (0, 1);
    for p in 0..n {
        for q in (p + 1)..n {
            let val = s.get(p, q).abs();
            if val > max_abs {
                max_abs = val;
                max_p = p;
                max_q = q;
            }
        }
    }
    if !max_abs.is_zero() {
        jacobi_rotate(&mut s, &mut v, max_p, max_q, FixedPoint::ZERO);
    }

    let mut eigen_pairs: Vec<(FixedPoint, usize)> = (0..n)
        .map(|i| (s.get(i, i), i))
        .collect();
    eigen_pairs.sort_by(|a, b| b.0.abs().partial_cmp(&a.0.abs()).unwrap_or(std::cmp::Ordering::Equal));
    let mut values = FixedVector::new(n);
    let mut vectors = FixedMatrix::new(n, n);
    for (k, (val, orig_idx)) in eigen_pairs.iter().enumerate() {
        values[k] = *val;
        for r in 0..n {
            vectors.set(r, k, v.get(r, *orig_idx));
        }
    }
    Ok(EigenDecomposition { values, vectors })
}

/// Jacobi rotation coefficients `(c, s)` that zero the off-diagonal entry of the
/// symmetric 2×2 block `[[a_pp, a_pq], [a_pq, a_qq]]`.
///
/// When `|a_pp - a_qq| <= equal_tol`, the 45° rotation is used, with the sine
/// carrying the sign of `a_pq`. Otherwise `t` is the smaller-magnitude root of
/// `t² + 2·τ·t − 1 = 0` with `τ = (a_pp − a_qq)/(2·a_pq)`, computed at the
/// compute tier. Then `c = 1/√(1+t²)` and `s = t·c`.
fn jacobi_coefficients(
    a_pp: FixedPoint,
    a_qq: FixedPoint,
    a_pq: FixedPoint,
    equal_tol: FixedPoint,
) -> (FixedPoint, FixedPoint) {
    let one = FixedPoint::one();
    let two = FixedPoint::from_int(2);
    let diff = a_pp - a_qq;
    if diff.abs() <= equal_tol {
        let half = one / two;
        let sqrt2_inv = (one + one).try_sqrt().map(|root| one / root).unwrap_or(half);
        let sn = if a_pq.is_negative() { -sqrt2_inv } else { sqrt2_inv };
        return (sqrt2_inv, sn);
    }
    let one_compute = upscale_to_compute(one.raw());
    // If the division fails, fall back to the raw difference (as before).
    let tau = compute_divide(
        upscale_to_compute(diff.raw()),
        upscale_to_compute((two * a_pq).raw()),
    )
    .unwrap_or(upscale_to_compute(diff.raw()));
    // t = sign(τ) / (|τ| + √(1 + τ²))
    let sqrt_disc = sqrt_at_compute_tier(compute_add(one_compute, compute_multiply(tau, tau)));
    let abs_tau = if compute_is_negative(&tau) { compute_negate(tau) } else { tau };
    let t_mag = compute_divide(one_compute, compute_add(abs_tau, sqrt_disc))
        .unwrap_or(one_compute);
    let t = if compute_is_negative(&tau) { compute_negate(t_mag) } else { t_mag };
    // c = 1 / √(1 + t²), s = t·c
    let sqrt_1pt = sqrt_at_compute_tier(compute_add(one_compute, compute_multiply(t, t)));
    let cs = compute_divide(one_compute, sqrt_1pt).unwrap_or(one_compute);
    let sn = compute_multiply(t, cs);
    (
        FixedPoint::from_raw(round_to_storage(cs)),
        FixedPoint::from_raw(round_to_storage(sn)),
    )
}

/// Applies one Jacobi rotation in the `(p, q)` plane.
///
/// The symmetric working matrix `s` is updated on both triangles and its
/// `(p, q)` entry is set to exactly zero. The rotation is also accumulated into
/// the columns `p` and `q` of `v`.
fn jacobi_rotate(
    s: &mut FixedMatrix,
    v: &mut FixedMatrix,
    p: usize,
    q: usize,
    equal_tol: FixedPoint,
) {
    let n = s.rows();
    let two = FixedPoint::from_int(2);
    // None of these three entries is touched by the off-row updates below.
    let a_pq = s.get(p, q);
    let a_pp = s.get(p, p);
    let a_qq = s.get(q, q);
    let (cs, sn) = jacobi_coefficients(a_pp, a_qq, a_pq, equal_tol);
    for r in (0..n).filter(|&r| r != p && r != q) {
        let (new_rp, new_rq) = apply_givens_compute(cs, sn, s.get(r, p), s.get(r, q));
        s.set(r, p, new_rp);
        s.set(p, r, new_rp);
        s.set(r, q, new_rq);
        s.set(q, r, new_rq);
    }
    // New diagonal of the 2×2 block, with each entry as one compute-tier dot product.
    let cs_sq = cs * cs;
    let sn_sq = sn * sn;
    let cs_sn_2 = two * cs * sn;
    let block = [a_pp.raw(), a_pq.raw(), a_qq.raw()];
    let new_pp = FixedPoint::from_raw(compute_tier_dot_raw(
        &[cs_sq.raw(), cs_sn_2.raw(), sn_sq.raw()],
        &block,
    ));
    let new_qq = FixedPoint::from_raw(compute_tier_dot_raw(
        &[sn_sq.raw(), (-cs_sn_2).raw(), cs_sq.raw()],
        &block,
    ));
    s.set(p, p, new_pp);
    s.set(q, q, new_qq);
    s.set(p, q, FixedPoint::ZERO);
    s.set(q, p, FixedPoint::ZERO);
    for r in 0..n {
        let (new_vp, new_vq) = apply_givens_compute(cs, sn, v.get(r, p), v.get(r, q));
        v.set(r, p, new_vp);
        v.set(r, q, new_vq);
    }
}
/// Singular value decomposition `A = U·diag(σ)·Vᵀ`, as produced by [`svd_decompose`].
#[derive(Debug, Clone)]
pub struct SVDDecomposition {
    /// `m × m` matrix of left singular vectors (columns).
    pub u: FixedMatrix,
    /// `min(m, n)` non-negative singular values in decreasing order.
    pub sigma: FixedVector,
    /// `n × n` matrix of right singular vectors, stored as rows (i.e. `Vᵀ`).
    pub vt: FixedMatrix,
}
/// Computes the singular value decomposition `A = U·diag(σ)·Vᵀ`.
///
/// Uses the Golub–Kahan approach:
/// 1. Householder bidiagonalization.
/// 2. Implicit-shift QR sweeps on the bidiagonal `(d, e)` with a Wilkinson
///    shift. A negligible super-diagonal splits the problem; a negligible
///    diagonal entry is deflated by a chain of Givens rotations.
///
/// For `m < n` the decomposition of `Aᵀ` is computed and transposed back.
/// `sigma` has `min(m, n)` non-negative entries in decreasing order. `u` is
/// `m × m` and `vt` is `n × n`.
///
/// NOTE(review): the QR phase is capped at `30·n²` steps. If it has not
/// converged by then, the partially reduced values are returned without an error.
///
/// # Errors
/// Propagates errors from the fixed-point square roots used for Householder norms.
pub fn svd_decompose(a: &FixedMatrix) -> Result<SVDDecomposition, OverflowDetected> {
    let (m, n) = (a.rows(), a.cols());
    if m == 0 || n == 0 {
        return Ok(SVDDecomposition {
            u: FixedMatrix::identity(m),
            sigma: FixedVector::new(0),
            vt: FixedMatrix::identity(n),
        });
    }
    if m < n {
        // Aᵀ = U'·Σ·V'ᵀ  ⇒  A = V'·Σ·U'ᵀ
        let at = a.transpose();
        let mut result = svd_decompose(&at)?;
        let u_new = result.vt.transpose();
        let vt_new = result.u.transpose();
        result.u = u_new;
        result.vt = vt_new;
        return Ok(result);
    }
    let mut b = a.clone();
    let mut u_acc = FixedMatrix::identity(m);
    let mut v_acc = FixedMatrix::identity(n);
    let two = FixedPoint::from_int(2);
    let k = n.min(m);

    // Phase 1: Householder bidiagonalization, b <- Uᵀ·A·V.
    for j in 0..k {
        // Left reflector: annihilate b[j+1..m][j].
        if j < m {
            let col_len = m - j;
            let x_raw: Vec<BinaryStorage> = (j..m).map(|i| b.get(i, j).raw()).collect();
            let norm_sq = FixedPoint::from_raw(compute_tier_dot_raw(&x_raw, &x_raw));
            if !norm_sq.is_zero() {
                let norm_x = norm_sq.try_sqrt()?;
                let x_0 = b.get(j, j);
                let alpha = if x_0.is_negative() { norm_x } else { -norm_x };
                let mut v_hh = Vec::<FixedPoint>::with_capacity(col_len);
                v_hh.push(x_0 - alpha);
                for i in 1..col_len {
                    v_hh.push(FixedPoint::from_raw(x_raw[i]));
                }
                let v_raw: Vec<BinaryStorage> = v_hh.iter().map(|fp| fp.raw()).collect();
                let vtv = FixedPoint::from_raw(compute_tier_dot_raw(&v_raw, &v_raw));
                if !vtv.is_zero() {
                    for c in j..n {
                        let col_raw: Vec<BinaryStorage> = (j..m).map(|i| b.get(i, c).raw()).collect();
                        let vt_col = FixedPoint::from_raw(compute_tier_dot_raw(&v_raw, &col_raw));
                        let scale = two * vt_col / vtv;
                        for i in j..m {
                            b.set(i, c, b.get(i, c) - scale * v_hh[i - j]);
                        }
                    }
                    for r in 0..m {
                        let u_row_raw: Vec<BinaryStorage> = (j..m).map(|c| u_acc.get(r, c).raw()).collect();
                        let dot = FixedPoint::from_raw(compute_tier_dot_raw(&u_row_raw, &v_raw));
                        let scale = two * dot / vtv;
                        for c in j..m {
                            u_acc.set(r, c, u_acc.get(r, c) - scale * v_hh[c - j]);
                        }
                    }
                }
            }
        }
        // Right reflector: annihilate b[j][j+2..n].
        if j + 1 < n {
            let row_start = j + 1;
            let row_len = n - row_start;
            if row_len > 0 {
                let x_raw: Vec<BinaryStorage> = (row_start..n).map(|c| b.get(j, c).raw()).collect();
                let norm_sq = FixedPoint::from_raw(compute_tier_dot_raw(&x_raw, &x_raw));
                if !norm_sq.is_zero() {
                    let norm_x = norm_sq.try_sqrt()?;
                    let x_0 = b.get(j, row_start);
                    let alpha = if x_0.is_negative() { norm_x } else { -norm_x };
                    let mut v_hh = Vec::<FixedPoint>::with_capacity(row_len);
                    v_hh.push(x_0 - alpha);
                    for i in 1..row_len {
                        v_hh.push(FixedPoint::from_raw(x_raw[i]));
                    }
                    let v_raw: Vec<BinaryStorage> = v_hh.iter().map(|fp| fp.raw()).collect();
                    let vtv = FixedPoint::from_raw(compute_tier_dot_raw(&v_raw, &v_raw));
                    if !vtv.is_zero() {
                        for r in j..m {
                            let row_raw: Vec<BinaryStorage> = (row_start..n).map(|c| b.get(r, c).raw()).collect();
                            let dot = FixedPoint::from_raw(compute_tier_dot_raw(&row_raw, &v_raw));
                            let scale = two * dot / vtv;
                            for c in row_start..n {
                                b.set(r, c, b.get(r, c) - scale * v_hh[c - row_start]);
                            }
                        }
                        for r in 0..n {
                            let v_row_raw: Vec<BinaryStorage> = (row_start..n).map(|c| v_acc.get(r, c).raw()).collect();
                            let dot = FixedPoint::from_raw(compute_tier_dot_raw(&v_row_raw, &v_raw));
                            let scale = two * dot / vtv;
                            for c in row_start..n {
                                v_acc.set(r, c, v_acc.get(r, c) - scale * v_hh[c - row_start]);
                            }
                        }
                    }
                }
            }
        }
    }

    // Phase 2: implicit-shift QR on the bidiagonal; d = diagonal, e = super-diagonal.
    let mut d: Vec<FixedPoint> = (0..n).map(|i| b.get(i, i)).collect();
    let mut e: Vec<FixedPoint> = (0..n.saturating_sub(1)).map(|i| b.get(i, i + 1)).collect();
    let max_iter = 30 * n * n;
    let mut iter_count = 0usize;
    let mut q_end = n;
    while q_end > 1 && iter_count < max_iter {
        // Drop trailing rows whose super-diagonal coupling is negligible.
        let mut found_active = false;
        for idx in (1..q_end).rev() {
            let thresh_val = convergence_threshold(d[idx].abs().max(d[idx - 1].abs()));
            if e[idx - 1].abs() <= thresh_val {
                if idx == q_end - 1 {
                    q_end -= 1;
                } else {
                    found_active = true;
                    break;
                }
            } else {
                found_active = true;
                break;
            }
        }
        if !found_active || q_end <= 1 {
            break;
        }
        // The active unreduced block spans indices p..=q.
        let q = q_end - 1;
        let mut p = q;
        while p > 0 {
            let thresh_val = convergence_threshold(d[p].abs().max(d[p - 1].abs()));
            if e[p - 1].abs() <= thresh_val {
                break;
            }
            p -= 1;
        }
        {
            let mut deflated = false;
            // Negligible d[q]: chase e[q-1] up column q with right rotations (updates V).
            {
                let d_thresh = convergence_threshold(
                    if q > 0 { d[q - 1].abs().max(e[q - 1].abs()) }
                    else { e[0].abs().max(FixedPoint::one()) }
                );
                if d[q].abs() <= d_thresh {
                    let mut bulge = e[q - 1];
                    e[q - 1] = FixedPoint::ZERO;
                    for j in (p..q).rev() {
                        let (cs, sn) = givens(d[j], bulge);
                        d[j] = cs * d[j] + sn * bulge;
                        if j > p {
                            bulge = -sn * e[j - 1];
                            e[j - 1] = cs * e[j - 1];
                        }
                        for r in 0..n {
                            let v_rj = v_acc.get(r, j);
                            let v_rq = v_acc.get(r, q);
                            v_acc.set(r, j, cs * v_rj + sn * v_rq);
                            v_acc.set(r, q, -sn * v_rj + cs * v_rq);
                        }
                    }
                    deflated = true;
                }
            }
            // Negligible d[i] (i < q): chase e[i] along row i with left rotations (updates U).
            if !deflated {
                for i in p..q {
                    let d_thresh = convergence_threshold(
                        e[i].abs().max(
                            if i > 0 && i - 1 < e.len() { e[i.saturating_sub(1)].abs() }
                            else { FixedPoint::one() }
                        )
                    );
                    if d[i].abs() <= d_thresh {
                        let mut bulge = e[i];
                        e[i] = FixedPoint::ZERO;
                        for j in (i + 1)..=q {
                            // Rows rotate as row_j <- c·row_j + s·row_i, row_i <- -s·row_j + c·row_i.
                            let (cs, sn) = givens(d[j], bulge);
                            d[j] = cs * d[j] + sn * bulge;
                            if j < q {
                                bulge = -sn * e[j];
                                e[j] = cs * e[j];
                            }
                            // U absorbs Gᵀ: u_j <- c·u_j + s·u_i, u_i <- -s·u_j + c·u_i.
                            // (Previously the roles of i and j were swapped, i.e. G was
                            // applied instead of Gᵀ, which broke A = U·Σ·Vᵀ.)
                            for r in 0..m {
                                let u_ri = u_acc.get(r, i);
                                let u_rj = u_acc.get(r, j);
                                u_acc.set(r, j, cs * u_rj + sn * u_ri);
                                u_acc.set(r, i, -sn * u_rj + cs * u_ri);
                            }
                        }
                        deflated = true;
                        break;
                    }
                }
            }
            if deflated {
                iter_count += 1;
                continue;
            }
        }
        // Wilkinson shift: eigenvalue of the trailing 2×2 of BᵀB (restricted to
        // the active block) closest to its bottom-right entry.
        let shift = {
            let dq = d[q];
            let eq_1 = e[q - 1];
            let dq_1 = d[q - 1];
            // e[q-2] belongs to the active block only when q-2 >= p.
            let f = dq_1 * dq_1 + if q >= p + 2 { e[q - 2] * e[q - 2] } else { FixedPoint::ZERO };
            let g = dq * dq + eq_1 * eq_1;
            let h = dq_1 * eq_1;
            let half = FixedPoint::one() / FixedPoint::from_int(2);
            let diff = (f - g) * half;
            if diff.is_zero() && h.is_zero() {
                g
            } else {
                let disc_sq = diff * diff + h * h;
                let disc = disc_sq.try_sqrt().unwrap_or(diff.abs());
                let signed_disc = if diff.is_negative() { -disc } else { disc };
                g - h * h / (diff + signed_disc)
            }
        };
        // Implicit QR step: alternate right (V) and left (U) rotations chasing the bulge down.
        let mut x = d[p] * d[p] - shift;
        let mut z = d[p] * e[p];
        for i in p..q {
            let (cs, sn) = givens(x, z);
            if i > p {
                e[i - 1] = FixedPoint::from_raw(compute_tier_dot_raw(
                    &[cs.raw(), sn.raw()], &[e[i - 1].raw(), z.raw()]
                ));
            }
            let old_di = d[i];
            let old_ei = e[i];
            let (new_di, new_ei) = apply_givens_compute(cs, sn, old_di, old_ei);
            d[i] = new_di;
            e[i] = new_ei;
            let bulge = sn * d[i + 1];
            d[i + 1] = cs * d[i + 1];
            for r in 0..n {
                let (new_v0, new_v1) = apply_givens_compute(
                    cs, sn, v_acc.get(r, i), v_acc.get(r, i + 1));
                v_acc.set(r, i, new_v0);
                v_acc.set(r, i + 1, new_v1);
            }
            x = d[i];
            z = bulge;
            let (cs2, sn2) = givens(x, z);
            d[i] = FixedPoint::from_raw(compute_tier_dot_raw(
                &[cs2.raw(), sn2.raw()], &[d[i].raw(), bulge.raw()]
            ));
            let old_ei = e[i];
            let old_di1 = d[i + 1];
            let (new_ei, new_di1) = apply_givens_compute(cs2, sn2, old_ei, old_di1);
            e[i] = new_ei;
            d[i + 1] = new_di1;
            for r in 0..m {
                let (new_u0, new_u1) = apply_givens_compute(
                    cs2, sn2, u_acc.get(r, i), u_acc.get(r, i + 1));
                u_acc.set(r, i, new_u0);
                u_acc.set(r, i + 1, new_u1);
            }
            if i + 1 < q {
                x = e[i];
                z = sn2 * e[i + 1];
                e[i + 1] = cs2 * e[i + 1];
            }
        }
        iter_count += 1;
    }

    // Make singular values non-negative by flipping the matching column of V.
    for i in 0..n {
        if d[i].is_negative() {
            d[i] = -d[i];
            for r in 0..n {
                v_acc.set(r, i, -v_acc.get(r, i));
            }
        }
    }
    // Sort singular values in decreasing order and permute U / Vᵀ accordingly.
    let mut indices: Vec<usize> = (0..n).collect();
    indices.sort_by(|&a, &b| d[b].partial_cmp(&d[a]).unwrap_or(std::cmp::Ordering::Equal));
    let mut sigma = FixedVector::new(n);
    let mut u_sorted = FixedMatrix::new(m, m);
    let mut vt_sorted = FixedMatrix::new(n, n);
    for (new_idx, &old_idx) in indices.iter().enumerate() {
        sigma[new_idx] = d[old_idx];
        for r in 0..m {
            u_sorted.set(r, new_idx, u_acc.get(r, old_idx));
        }
        for r in 0..n {
            vt_sorted.set(new_idx, r, v_acc.get(r, old_idx));
        }
    }
    // Columns n..m of U complete the orthonormal basis and keep their order.
    for new_idx in n..m {
        for r in 0..m {
            u_sorted.set(r, new_idx, u_acc.get(r, new_idx));
        }
    }
    Ok(SVDDecomposition {
        u: u_sorted,
        sigma,
        vt: vt_sorted,
    })
}
/// Real Schur decomposition `A = Q·T·Qᵀ`, as produced by [`schur_decompose`].
#[derive(Debug, Clone)]
pub struct SchurDecomposition {
    /// Accumulated orthogonal similarity transform (starts from the identity).
    pub q: FixedMatrix,
    /// Quasi-upper-triangular factor (1×1 and 2×2 diagonal blocks).
    pub t: FixedMatrix,
}
/// Computes a real Schur decomposition `A = Q·T·Qᵀ` of a square matrix.
///
/// Steps:
/// 1. Householder reduction to upper Hessenberg form.
/// 2. Francis double-shift QR sweeps on the active window `0..nn`. Each sweep
///    chases a bulge with 3-row reflectors and ends with one Givens rotation
///    on the last two rows, which restores Hessenberg form.
///
/// Trailing 1×1 or 2×2 blocks deflate when their coupling falls below
/// `convergence_threshold`, and the coupling is set to zero. 2×2 blocks are
/// left unreduced, so `T` is quasi-upper-triangular. This includes the leading
/// 2×2 block once the window shrinks to two rows. Remaining negligible
/// sub-diagonal entries are zeroed at the end.
///
/// NOTE(review): iteration is capped at `30·n²` sweeps and non-convergence is
/// not reported.
///
/// # Panics
/// Panics if `a` is not square.
///
/// # Errors
/// Propagates errors from the fixed-point square roots used for reflector norms.
pub fn schur_decompose(a: &FixedMatrix) -> Result<SchurDecomposition, OverflowDetected> {
    assert!(a.is_square(), "schur_decompose: matrix must be square");
    let n = a.rows();
    if n <= 1 {
        return Ok(SchurDecomposition {
            q: FixedMatrix::identity(n),
            t: a.clone(),
        });
    }
    let mut h = a.clone();
    let mut q_acc = FixedMatrix::identity(n);
    let two = FixedPoint::from_int(2);

    // Phase 1: Hessenberg reduction, h <- P·h·P and q_acc <- q_acc·P for each reflector P.
    for k in 0..n.saturating_sub(2) {
        let col_len = n - k - 1;
        let start = k + 1;
        let x_raw: Vec<BinaryStorage> = (start..n).map(|i| h.get(i, k).raw()).collect();
        let norm_sq = FixedPoint::from_raw(compute_tier_dot_raw(&x_raw, &x_raw));
        if norm_sq.is_zero() {
            continue;
        }
        let norm_x = norm_sq.try_sqrt()?;
        let x_0 = h.get(start, k);
        let alpha = if x_0.is_negative() { norm_x } else { -norm_x };
        let mut v_hh = Vec::<FixedPoint>::with_capacity(col_len);
        v_hh.push(x_0 - alpha);
        for i in 1..col_len {
            v_hh.push(FixedPoint::from_raw(x_raw[i]));
        }
        let v_raw: Vec<BinaryStorage> = v_hh.iter().map(|fp| fp.raw()).collect();
        let vtv = FixedPoint::from_raw(compute_tier_dot_raw(&v_raw, &v_raw));
        if vtv.is_zero() {
            continue;
        }
        for c in 0..n {
            let col_raw: Vec<BinaryStorage> = (start..n).map(|i| h.get(i, c).raw()).collect();
            let dot = FixedPoint::from_raw(compute_tier_dot_raw(&v_raw, &col_raw));
            let scale = two * dot / vtv;
            for i in start..n {
                h.set(i, c, h.get(i, c) - scale * v_hh[i - start]);
            }
        }
        for r in 0..n {
            let row_raw: Vec<BinaryStorage> = (start..n).map(|c| h.get(r, c).raw()).collect();
            let dot = FixedPoint::from_raw(compute_tier_dot_raw(&row_raw, &v_raw));
            let scale = two * dot / vtv;
            for c in start..n {
                h.set(r, c, h.get(r, c) - scale * v_hh[c - start]);
            }
        }
        for r in 0..n {
            let q_row_raw: Vec<BinaryStorage> = (start..n).map(|c| q_acc.get(r, c).raw()).collect();
            let dot = FixedPoint::from_raw(compute_tier_dot_raw(&q_row_raw, &v_raw));
            let scale = two * dot / vtv;
            for c in start..n {
                q_acc.set(r, c, q_acc.get(r, c) - scale * v_hh[c - start]);
            }
        }
    }

    // Phase 2: Francis double-shift QR on the active window 0..nn.
    let max_iter = 30 * n * n;
    let mut iter_count = 0usize;
    let mut nn = n;
    while nn > 2 && iter_count < max_iter {
        let thresh = convergence_threshold(
            h.get(nn - 1, nn - 1).abs().max(h.get(nn - 2, nn - 2).abs())
        );
        if h.get(nn - 1, nn - 2).abs() <= thresh {
            // Trailing 1×1 block converged: decouple it exactly.
            h.set(nn - 1, nn - 2, FixedPoint::ZERO);
            nn -= 1;
            continue;
        }
        // nn > 2, so nn - 3 is a valid index.
        let thresh2 = convergence_threshold(
            h.get(nn - 2, nn - 2).abs().max(h.get(nn - 3, nn - 3).abs())
        );
        if h.get(nn - 2, nn - 3).abs() <= thresh2 {
            // Trailing 2×2 block decoupled: decouple it exactly and leave it unreduced.
            h.set(nn - 2, nn - 3, FixedPoint::ZERO);
            nn -= 2;
            continue;
        }
        // Top row l of the unreduced block ending at nn-1. The two checks above
        // guarantee l <= nn - 3, so the window always has at least three rows.
        let mut l = nn - 2;
        while l > 0 {
            let thresh_l = convergence_threshold(
                h.get(l, l).abs().max(h.get(l - 1, l - 1).abs())
            );
            if h.get(l, l - 1).abs() <= thresh_l {
                break;
            }
            l -= 1;
        }
        if l > 0 {
            // Split point: drop the negligible coupling so the first reflector does
            // not spread it below the sub-diagonal.
            h.set(l, l - 1, FixedPoint::ZERO);
        }
        // Double shift from the trailing 2×2 block: s = trace, p = determinant.
        let s = h.get(nn - 2, nn - 2) + h.get(nn - 1, nn - 1);
        let p = h.get(nn - 2, nn - 2) * h.get(nn - 1, nn - 1)
            - h.get(nn - 2, nn - 1) * h.get(nn - 1, nn - 2);
        // First column of H² − s·H + p·I, restricted to rows l..l+2.
        let h_ll = h.get(l, l);
        let h_l1l = h.get(l + 1, l);
        let h_ll1 = h.get(l, l + 1);
        let h_l1l1 = h.get(l + 1, l + 1);
        let mut x = h_ll * h_ll + h_ll1 * h_l1l - s * h_ll + p;
        let mut y = h_l1l * (h_ll + h_l1l1 - s);
        let mut z = if l + 2 < nn { h_l1l * h.get(l + 2, l + 1) } else { FixedPoint::ZERO };
        // Bulge chase: 3-row reflectors for k = l..=nn-3, then a Givens rotation on
        // rows nn-2, nn-1 (k = nn-2) that removes the last bulge entry h[nn-1][nn-3].
        // (The previous range l..nn-2 never reached this final step.)
        for k in l..(nn - 1) {
            if k + 2 < nn {
                let vec_raw = [x.raw(), y.raw(), z.raw()];
                let norm_sq = FixedPoint::from_raw(compute_tier_dot_raw(&vec_raw, &vec_raw));
                if norm_sq.is_zero() { break; }
                let norm_v = norm_sq.try_sqrt()?;
                let alpha = if x.is_negative() { norm_v } else { -norm_v };
                let v0 = x - alpha;
                let v1 = y;
                let v2 = z;
                let v_raw = [v0.raw(), v1.raw(), v2.raw()];
                let vtv = FixedPoint::from_raw(compute_tier_dot_raw(&v_raw, &v_raw));
                if vtv.is_zero() { break; }
                for c in 0..n {
                    let col_raw = [h.get(k, c).raw(), h.get(k + 1, c).raw(), h.get(k + 2, c).raw()];
                    let dot_val = FixedPoint::from_raw(compute_tier_dot_raw(&v_raw, &col_raw));
                    let scale = two * dot_val / vtv;
                    h.set(k, c, h.get(k, c) - scale * v0);
                    h.set(k + 1, c, h.get(k + 1, c) - scale * v1);
                    h.set(k + 2, c, h.get(k + 2, c) - scale * v2);
                }
                let c_end = nn.min(k + 4);
                for r in 0..c_end {
                    let row_raw = [h.get(r, k).raw(), h.get(r, k + 1).raw(), h.get(r, k + 2).raw()];
                    let dot_val = FixedPoint::from_raw(compute_tier_dot_raw(&row_raw, &v_raw));
                    let scale = two * dot_val / vtv;
                    h.set(r, k, h.get(r, k) - scale * v0);
                    h.set(r, k + 1, h.get(r, k + 1) - scale * v1);
                    h.set(r, k + 2, h.get(r, k + 2) - scale * v2);
                }
                for r in 0..n {
                    let q_raw = [q_acc.get(r, k).raw(), q_acc.get(r, k + 1).raw(), q_acc.get(r, k + 2).raw()];
                    let dot_val = FixedPoint::from_raw(compute_tier_dot_raw(&q_raw, &v_raw));
                    let scale = two * dot_val / vtv;
                    q_acc.set(r, k, q_acc.get(r, k) - scale * v0);
                    q_acc.set(r, k + 1, q_acc.get(r, k + 1) - scale * v1);
                    q_acc.set(r, k + 2, q_acc.get(r, k + 2) - scale * v2);
                }
            } else {
                let (cs, sn) = givens(x, y);
                for c in 0..n {
                    let (new0, new1) = apply_givens_compute(
                        cs, sn, h.get(k, c), h.get(k + 1, c));
                    h.set(k, c, new0);
                    h.set(k + 1, c, new1);
                }
                for r in 0..nn {
                    let (new0, new1) = apply_givens_compute(
                        cs, sn, h.get(r, k), h.get(r, k + 1));
                    h.set(r, k, new0);
                    h.set(r, k + 1, new1);
                }
                for r in 0..n {
                    let (new0, new1) = apply_givens_compute(
                        cs, sn, q_acc.get(r, k), q_acc.get(r, k + 1));
                    q_acc.set(r, k, new0);
                    q_acc.set(r, k + 1, new1);
                }
            }
            // Next step works on the bulge now sitting in column k.
            if k + 3 < nn {
                x = h.get(k + 1, k);
                y = h.get(k + 2, k);
                z = h.get(k + 3, k);
            } else if k + 2 < nn {
                x = h.get(k + 1, k);
                y = h.get(k + 2, k);
            }
        }
        iter_count += 1;
    }

    // Flush remaining negligible sub-diagonal entries to exact zero.
    for i in 1..n {
        let thresh = convergence_threshold(
            h.get(i, i).abs().max(h.get(i - 1, i - 1).abs())
        );
        if h.get(i, i - 1).abs() <= thresh {
            h.set(i, i - 1, FixedPoint::ZERO);
        }
    }
    Ok(SchurDecomposition { q: q_acc, t: h })
}