use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::numeric::{Float, NumAssign, NumCast, One, ToPrimitive, Zero};
use std::fmt::Debug;
use super::conversions::{convert, convert_2d};
use crate::decomposition::svd;
use crate::error::{LinalgError, LinalgResult};
#[allow(dead_code)]
pub fn mixed_precision_solve<A, B, C, H>(
a: &ArrayView2<A>,
b: &ArrayView1<B>,
) -> LinalgResult<Array1<C>>
where
A: Clone + Debug + ToPrimitive + Copy,
B: Clone + Debug + ToPrimitive + Copy,
C: Clone + Zero + NumCast + Debug,
H: Float + Clone + NumCast + Debug + Zero + ToPrimitive + NumAssign,
{
let ashape = a.shape();
if ashape[0] != ashape[1] {
return Err(LinalgError::ShapeError(format!(
"Matrix must be square, got shape {ashape:?}"
)));
}
if ashape[0] != b.len() {
return Err(LinalgError::ShapeError(format!(
"Matrix rows ({}) must match vector length ({})",
ashape[0],
b.len()
)));
}
let a_high = convert_2d::<A, H>(a);
let b_high = convert::<B, H>(b);
let n = ashape[0];
let mut aug = Array2::<H>::zeros((n, n + 1));
for i in 0..n {
for j in 0..n {
aug[[i, j]] = a_high[[i, j]];
}
aug[[i, n]] = b_high[i];
}
for i in 0..n {
let mut max_row = i;
let mut max_val = aug[[i, i]].abs();
for j in i + 1..n {
let val = aug[[j, i]].abs();
if val > max_val {
max_row = j;
max_val = val;
}
}
if max_val < H::epsilon() {
return Err(LinalgError::SingularMatrixError(
"Matrix is singular".to_string(),
));
}
if max_row != i {
for j in i..=n {
let temp = aug[[i, j]];
aug[[i, j]] = aug[[max_row, j]];
aug[[max_row, j]] = temp;
}
}
for j in i + 1..n {
let factor = aug[[j, i]] / aug[[i, i]];
aug[[j, i]] = H::zero();
for k in i + 1..=n {
aug[[j, k]] = aug[[j, k]] - factor * aug[[i, k]];
}
}
}
let mut x_high = Array1::<H>::zeros(n);
for i in (0..n).rev() {
let mut sum = H::zero();
for j in i + 1..n {
sum += aug[[i, j]] * x_high[j];
}
x_high[i] = (aug[[i, n]] - sum) / aug[[i, i]];
}
let mut result = Array1::<C>::zeros(n);
for (i, &val) in x_high.iter().enumerate() {
result[i] = C::from(val).unwrap_or_else(|| C::zero());
}
Ok(result)
}
/// Computes the 2-norm condition number `sigma_max / sigma_min` of `a`.
///
/// `a` is converted to the higher-precision type `H`, its singular values are
/// obtained from `svd`, and the ratio is converted to the output type `C`.
/// If the smallest singular value is exactly zero, the result is
/// `H::infinity()` (converted to `C`). Near-singular matrices produce very
/// large finite values.
///
/// # Arguments
/// * `a` - input matrix
/// * `p` - norm selector; only `None` (the 2-norm) is supported
///
/// # Errors
/// * `LinalgError::NotImplementedError` if `p` is `Some(_)`.
/// * Any error returned by `svd`.
/// * `LinalgError::ComputationError` if the result cannot be represented in `C`.
#[allow(dead_code)]
pub fn mixed_precision_cond<A, C, H>(a: &ArrayView2<A>, p: Option<H>) -> LinalgResult<C>
where
    A: Clone + Debug + ToPrimitive + Copy,
    C: Clone + Zero + NumCast + Debug,
    H: Float
        + Clone
        + NumCast
        + Debug
        + ToPrimitive
        + 'static
        + std::iter::Sum
        + NumAssign
        + scirs2_core::ndarray::ScalarOperand
        + Send
        + Sync,
{
    // Reject unsupported norms before paying for the SVD.
    if p.is_some() {
        return Err(LinalgError::NotImplementedError(
            "Only 2-norm condition number is currently implemented".to_string(),
        ));
    }

    let a_high = convert_2d::<A, H>(a);
    let (_, s, _) = svd(&a_high.view(), false, None)?;

    let s_max = s.iter().copied().fold(H::zero(), |acc, x| acc.max(x));
    // Use the true smallest singular value. Discarding "tiny" values (as an
    // epsilon filter would) can make a singular matrix look well conditioned
    // and makes an all-zero matrix report 0.
    let s_min = s.iter().copied().fold(H::infinity(), |acc, x| acc.min(x));

    let cond = if s_min == H::zero() {
        H::infinity()
    } else {
        s_max / s_min
    };

    C::from(cond).ok_or_else(|| {
        LinalgError::ComputationError(
            "Failed to convert condition number to output type".to_string(),
        )
    })
}
/// Solves `A x = b` by mixed-precision iterative refinement.
///
/// The initial solution and every correction are computed with `solve` in the
/// working precision `W`. Residuals `r = b - A x` and the accumulated solution
/// are kept in the higher precision `H`. The result is converted to `C`.
///
/// # Arguments
/// * `a` - square coefficient matrix
/// * `b` - right-hand side; `b.len()` must equal `a.nrows()`
/// * `max_iter` - maximum number of refinement steps (default 10)
/// * `tol` - refinement stops once the maximum absolute residual component is
///   below this absolute value (default `1e-8`)
///
/// # Errors
/// * `LinalgError::ShapeError` for a non-square `a` or a mismatched `b`.
/// * Any error returned by `solve`.
/// * `LinalgError::ComputationError` if the default tolerance cannot be
///   represented in `H`, or if a residual becomes non-finite (NaN/inf). A
///   NaN residual would otherwise compare as "converged".
#[allow(dead_code)]
pub fn iterative_refinement_solve<A, B, C, H, W>(
    a: &ArrayView2<A>,
    b: &ArrayView1<B>,
    max_iter: Option<usize>,
    tol: Option<H>,
) -> LinalgResult<Array1<C>>
where
    A: Float + NumAssign + Debug + 'static,
    B: Float + NumAssign + Debug + 'static,
    C: Float + NumAssign + Debug + 'static,
    H: Float
        + NumAssign
        + Debug
        + 'static
        + std::iter::Sum
        + scirs2_core::ndarray::ScalarOperand
        + Send
        + Sync,
    W: Float
        + NumAssign
        + Debug
        + 'static
        + std::iter::Sum
        + One
        + Send
        + Sync
        + scirs2_core::ndarray::ScalarOperand,
    A: NumCast,
    B: NumCast,
    C: NumCast,
    H: NumCast,
    W: NumCast,
{
    let max_iter = max_iter.unwrap_or(10);
    if a.nrows() != a.ncols() {
        return Err(LinalgError::ShapeError(format!(
            "Expected square matrix for iterative refinement, got {:?}",
            a.shape()
        )));
    }
    if a.nrows() != b.len() {
        return Err(LinalgError::ShapeError(format!(
            "Matrix shape {:?} is incompatible with vector length {}",
            a.shape(),
            b.len()
        )));
    }
    // Build the default tolerance only when the caller did not supply one.
    let tol = match tol {
        Some(t) => t,
        None => {
            let default_tol: Option<H> = NumCast::from(1e-8);
            default_tol.ok_or_else(|| {
                LinalgError::ComputationError(
                    "Failed to represent default tolerance".to_string(),
                )
            })?
        }
    };

    let a_h: Array2<H> = convert_2d(a);
    let b_h: Array1<H> = convert(b);
    let a_w: Array2<W> = convert_2d(a);
    let b_w: Array1<W> = convert(b);

    use crate::solve::solve;
    let x_w: Array1<W> = solve(&a_w.view(), &b_w.view(), None)?;
    let mut x_h: Array1<H> = convert(&x_w.view());

    for _ in 0..max_iter {
        // Residual r = b - A x, evaluated in the high precision H.
        let ax_h = a_h.dot(&x_h);
        let r_h: Array1<H> = &b_h - &ax_h;

        // A NaN would be skipped by the max-norm comparison below and look
        // like convergence, so surface non-finite residuals explicitly.
        if r_h.iter().any(|&v| !v.is_finite()) {
            return Err(LinalgError::ComputationError(
                "Non-finite residual encountered during iterative refinement".to_string(),
            ));
        }

        // Infinity norm of the residual.
        let r_norm = r_h.iter().fold(H::zero(), |acc, &v| acc.max(v.abs()));
        if r_norm < tol {
            break;
        }

        // Correction solved in working precision, accumulated in H.
        let r_w: Array1<W> = convert(&r_h.view());
        let dx_w: Array1<W> = solve(&a_w.view(), &r_w.view(), None)?;
        let dx_h: Array1<H> = convert(&dx_w.view());
        x_h += &dx_h;
    }

    let x_c: Array1<C> = convert(&x_h.view());
    Ok(x_c)
}
/// Computes a QR factorization `A = Q R` with Householder reflections.
///
/// The factorization is carried out in the higher-precision type `H` and both
/// factors are converted to `C`. For an `m x n` input, `Q` is `m x m` and `R`
/// is `m x n` with every entry below the diagonal set to zero.
///
/// This implementation produces no errors of its own. The `LinalgResult`
/// return type is kept for API consistency.
#[allow(dead_code)]
pub fn mixed_precision_qr<A, C, H>(a: &ArrayView2<A>) -> LinalgResult<(Array2<C>, Array2<C>)>
where
    A: Float + NumAssign + Debug + 'static,
    C: Float + NumAssign + Debug + 'static,
    H: Float
        + NumAssign
        + Debug
        + 'static
        + std::iter::Sum
        + scirs2_core::ndarray::ScalarOperand
        + Send
        + Sync,
    A: NumCast,
    C: NumCast,
    H: NumCast,
{
    // `a_h` is not needed after this point, so it becomes R directly.
    let mut r_h: Array2<H> = convert_2d(a);
    let m = r_h.nrows();
    let n = r_h.ncols();
    let two = H::one() + H::one();
    let mut q_h = Array2::<H>::eye(m);

    // `saturating_sub` avoids a usize underflow (panic) for an empty matrix.
    let steps = m.saturating_sub(1).min(n);
    for k in 0..steps {
        let len = m - k;

        // Euclidean norm of the sub-column r[k.., k].
        let mut norm_x = H::zero();
        for i in k..m {
            norm_x += r_h[[i, k]] * r_h[[i, k]];
        }
        norm_x = norm_x.sqrt();

        // Only an exactly zero sub-column can skip its reflection. A fixed
        // absolute cut-off would skip small-but-nonzero columns, whose entries
        // would then be discarded when the lower triangle is zeroed below.
        if norm_x == H::zero() {
            continue;
        }

        // Reflect onto -sign(r_kk) * ||x|| to avoid cancellation in v[0].
        let alpha = if r_h[[k, k]] < H::zero() {
            norm_x
        } else {
            -norm_x
        };
        let mut v = Array1::<H>::zeros(len);
        v[0] = r_h[[k, k]] - alpha;
        for i in 1..len {
            v[i] = r_h[[k + i, k]];
        }
        let v_norm = v.iter().fold(H::zero(), |sum, &x| sum + x * x).sqrt();
        if v_norm > H::zero() {
            for x in v.iter_mut() {
                *x /= v_norm;
            }
        }

        // R <- (I - 2 v v^T) R, restricted to rows k..m.
        for j in 0..n {
            let mut dot = H::zero();
            for i in 0..len {
                dot += v[i] * r_h[[k + i, j]];
            }
            for i in 0..len {
                r_h[[k + i, j]] -= two * v[i] * dot;
            }
        }

        // Q <- Q (I - 2 v v^T), restricted to columns k..m.
        for i in 0..m {
            let mut dot = H::zero();
            for j in 0..len {
                dot += q_h[[i, k + j]] * v[j];
            }
            for j in 0..len {
                q_h[[i, k + j]] -= two * dot * v[j];
            }
        }
    }

    // Clear rounding residue below the diagonal of R.
    for i in 0..m {
        for j in 0..i.min(n) {
            r_h[[i, j]] = H::zero();
        }
    }

    // Defensive re-normalization of Q's columns. Householder products are
    // already orthonormal up to rounding.
    for j in 0..m {
        let mut col_norm = H::zero();
        for i in 0..m {
            col_norm += q_h[[i, j]] * q_h[[i, j]];
        }
        col_norm = col_norm.sqrt();
        if col_norm > H::zero() {
            for i in 0..m {
                q_h[[i, j]] /= col_norm;
            }
        }
    }

    let q_c: Array2<C> = convert_2d(&q_h.view());
    let r_c: Array2<C> = convert_2d(&r_h.view());
    Ok((q_c, r_c))
}
/// Singular value decomposition computed in the higher-precision type `H`.
///
/// `a` is promoted to `H` and decomposed by `svd`. The factors `(U, S, Vt)`
/// are then converted to the output type `C`. `full_matrices` is forwarded
/// to `svd` unchanged, and any error from `svd` is propagated.
#[allow(dead_code)]
pub fn mixed_precision_svd<A, C, H>(
    a: &ArrayView2<A>,
    full_matrices: bool,
) -> LinalgResult<(Array2<C>, Array1<C>, Array2<C>)>
where
    A: Float + NumAssign + Debug + 'static,
    C: Float + NumAssign + Debug + 'static,
    H: Float
        + NumAssign
        + Debug
        + 'static
        + std::iter::Sum
        + scirs2_core::ndarray::ScalarOperand
        + Send
        + Sync,
    A: NumCast,
    C: NumCast,
    H: NumCast,
{
    let promoted: Array2<H> = convert_2d(a);
    let (left, sigma, right_t) = svd(&promoted.view(), full_matrices, None)?;
    // Tuple fields are evaluated left to right: U, then S, then Vt.
    Ok((
        convert_2d::<H, C>(&left.view()),
        convert::<H, C>(&sigma.view()),
        convert_2d::<H, C>(&right_t.view()),
    ))
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    /// Asserts that every entry of `m` matches the identity matrix within 1e-4.
    fn assert_identity(m: &Array2<f32>) {
        for ((row, col), &value) in m.indexed_iter() {
            let target: f32 = if row == col { 1.0 } else { 0.0 };
            assert_relative_eq!(value, target, epsilon = 1e-4);
        }
    }

    #[test]
    fn test_mixed_precision_solve() {
        let matrix = array![[2.0f32, 1.0], [1.0, 3.0]];
        let rhs = array![5.0f32, 8.0];
        let solution =
            mixed_precision_solve::<f32, f32, f32, f64>(&matrix.view(), &rhs.view())
                .expect("Operation failed");
        assert_eq!(solution.len(), 2);
        assert_relative_eq!(solution[0], 1.4f32, epsilon = 1e-4);
        assert_relative_eq!(solution[1], 2.2f32, epsilon = 1e-4);
    }

    #[test]
    fn test_mixed_precision_cond() {
        let scaled_identity = array![[2.0f32, 0.0], [0.0, 2.0]];
        let cond_identity = mixed_precision_cond::<f32, f32, f64>(&scaled_identity.view(), None)
            .expect("Operation failed");
        assert_relative_eq!(cond_identity, 1.0f32, epsilon = 1e-5);

        let nearly_singular = array![[1.0f32, 1.0], [1.0, 1.0001]];
        let cond_nearly_singular =
            mixed_precision_cond::<f32, f32, f64>(&nearly_singular.view(), None)
                .expect("Operation failed");
        assert!(cond_nearly_singular > 1000.0f32);
    }

    #[test]
    fn test_mixed_precision_qr() {
        let a = array![[1.0f32, 2.0], [3.0, 4.0]];
        let (q, r) = mixed_precision_qr::<f32, f32, f64>(&a.view()).expect("Operation failed");
        assert_eq!(q.shape(), &[2, 2]);
        assert_eq!(r.shape(), &[2, 2]);

        // Q must be orthogonal.
        assert_identity(&q.t().dot(&q));

        // Q R must reproduce A.
        let product = q.dot(&r);
        for ((row, col), &value) in product.indexed_iter() {
            assert_relative_eq!(value, a[[row, col]], epsilon = 1e-4);
        }
    }

    #[test]
    fn test_mixed_precision_svd() {
        let a = array![[1.0f32, 2.0], [3.0, 4.0]];
        let (u, s, vt) =
            mixed_precision_svd::<f32, f32, f64>(&a.view(), false).expect("Operation failed");
        assert_eq!(u.shape()[0], 2);
        assert_eq!(s.len(), 2);
        assert_eq!(vt.shape()[1], 2);
        assert!(s[0] >= s[1]);
        assert!(s[1] >= 0.0);

        // U must have orthonormal columns.
        assert_identity(&u.t().dot(&u));
    }

    #[test]
    fn test_mixed_precision_solve_errors() {
        // Non-square coefficient matrix.
        let non_square = array![[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let short_rhs = array![1.0f32, 2.0];
        assert!(
            mixed_precision_solve::<f32, f32, f32, f64>(&non_square.view(), &short_rhs.view())
                .is_err()
        );

        // Right-hand side length does not match the matrix.
        let square = array![[1.0f32, 2.0], [3.0, 4.0]];
        let long_rhs = array![1.0f32, 2.0, 3.0];
        assert!(
            mixed_precision_solve::<f32, f32, f32, f64>(&square.view(), &long_rhs.view())
                .is_err()
        );

        // Singular matrix (second row is twice the first).
        let singular = array![[1.0f32, 2.0], [2.0, 4.0]];
        let rhs = array![1.0f32, 2.0];
        assert!(
            mixed_precision_solve::<f32, f32, f32, f64>(&singular.view(), &rhs.view()).is_err()
        );
    }
}