use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::numeric::{Float, NumAssign, NumCast, ToPrimitive, Zero};
use std::fmt::Debug;
use crate::error::{LinalgError, LinalgResult};
#[cfg(feature = "simd")]
mod simd;
#[cfg(feature = "simd")]
pub use simd::{
simd_mixed_precision_dot_f32_f64, simd_mixed_precision_matmul_f32_f64,
simd_mixed_precision_matvec_f32_f64,
};
pub mod adaptive;
pub mod adaptive_precision;
pub mod conversions;
pub mod convert;
pub mod f16_gemm;
pub mod f32_ops;
pub mod f64_ops;
pub mod gemm;
pub mod operations;
pub mod types;
pub use conversions::{convert, convert_2d};
pub use f32_ops::{
mixed_precision_dot_f32, mixed_precision_matmul_f32_basic, mixed_precision_matvec_f32,
};
pub use f64_ops::mixed_precision_matmul_f64;
pub use adaptive::{
iterative_refinement_solve, mixed_precision_cond, mixed_precision_qr, mixed_precision_solve,
mixed_precision_svd,
};
/// Multiplies the matrix `a` by the vector `x` using the mixed-precision
/// matrix-vector kernel.
///
/// This function adds no logic of its own: both arguments and all four type
/// parameters are forwarded unchanged to
/// `f32_ops::mixed_precision_matvec_f32`.
///
/// # Type parameters
///
/// * `A` - element type of `a`
/// * `B` - element type of `x`
/// * `C` - element type of the returned vector
/// * `H` - floating-point type handed to the kernel (presumably the precision
///   used for intermediate arithmetic; confirm against `f32_ops`)
///
/// # Errors
///
/// Returns whatever error the underlying kernel reports.
#[allow(dead_code)]
pub fn mixed_precision_matvec<A, B, C, H>(
    a: &ArrayView2<A>,
    x: &ArrayView1<B>,
) -> LinalgResult<Array1<C>>
where
    A: Clone + Debug + ToPrimitive + Copy,
    B: Clone + Debug + ToPrimitive + Copy,
    C: Clone + Zero + NumCast + Debug,
    H: Float + Clone + NumCast + Debug + ToPrimitive,
{
    // Name the concrete kernel instantiation first, then invoke it.
    let kernel = f32_ops::mixed_precision_matvec_f32::<A, B, C, H>;
    kernel(a, x)
}
/// Multiplies the matrices `a` and `b`, picking a kernel from the combined
/// size of the operands.
///
/// The number of elements in `a` plus the number of elements in `b` is compared
/// against a fixed limit of 10000:
///
/// * below the limit the call goes to
///   `f32_ops::mixed_precision_matmul_f32_basic`;
/// * otherwise it goes to `f64_ops::mixed_precision_matmul_f64`.
///
/// No shape compatibility check is performed here; any such validation is left
/// to the selected kernel.
///
/// # Type parameters
///
/// * `A` - element type of `a`
/// * `B` - element type of `b`
/// * `C` - element type of the returned matrix
/// * `H` - floating-point type handed to the kernel (presumably the precision
///   used for intermediate arithmetic; confirm against the kernel modules)
///
/// # Errors
///
/// Returns whatever error the selected kernel reports.
#[allow(dead_code)]
pub fn mixed_precision_matmul<A, B, C, H>(
    a: &ArrayView2<A>,
    b: &ArrayView2<B>,
) -> LinalgResult<Array2<C>>
where
    A: Clone + Debug + ToPrimitive + Copy + Sync,
    B: Clone + Debug + ToPrimitive + Copy + Sync,
    C: Clone + Zero + NumCast + Debug + Send,
    H: Float + Clone + NumCast + Debug + ToPrimitive + NumAssign + Zero + Send + Sync,
{
    // Combined operand sizes strictly below this stay on the basic kernel.
    const BASIC_KERNEL_LIMIT: usize = 10000;

    let element_count = |dims: &[usize]| dims[0] * dims[1];
    let workload = element_count(a.shape()) + element_count(b.shape());

    if workload >= BASIC_KERNEL_LIMIT {
        return f64_ops::mixed_precision_matmul_f64::<A, B, C, H>(a, b);
    }
    f32_ops::mixed_precision_matmul_f32_basic::<A, B, C, H>(a, b)
}
/// Computes the dot product of the vectors `a` and `b` using the
/// mixed-precision dot-product kernel.
///
/// This function adds no logic of its own: both arguments and all four type
/// parameters are forwarded unchanged to `f32_ops::mixed_precision_dot_f32`.
///
/// # Type parameters
///
/// * `A` - element type of `a`
/// * `B` - element type of `b`
/// * `C` - type of the returned scalar
/// * `H` - floating-point type handed to the kernel (presumably the precision
///   used for accumulation; confirm against `f32_ops`)
///
/// # Errors
///
/// Returns whatever error the underlying kernel reports.
#[allow(dead_code)]
pub fn mixed_precision_dot<A, B, C, H>(a: &ArrayView1<A>, b: &ArrayView1<B>) -> LinalgResult<C>
where
    A: Clone + Debug + ToPrimitive + Copy,
    B: Clone + Debug + ToPrimitive + Copy,
    C: Clone + Zero + NumCast + Debug,
    H: Float + Clone + NumCast + Debug + ToPrimitive,
{
    // Name the concrete kernel instantiation first, then invoke it.
    let kernel = f32_ops::mixed_precision_dot_f32::<A, B, C, H>;
    kernel(a, b)
}
/// Computes the inverse of a square matrix by Gauss-Jordan elimination
/// performed in the working type `H`.
///
/// The input is converted to `H` with `conversions::convert_2d`, placed next to
/// an identity matrix in an `n x 2n` augmented array, and reduced using partial
/// (row) pivoting. The right-hand half of the reduced array is then converted
/// to `C` and returned.
///
/// # Singularity threshold
///
/// A pivot column is rejected when its largest remaining magnitude is exactly
/// zero, or falls below `H::epsilon() * min(1, s)` where `s` is the largest
/// entry magnitude of the converted input. Scaling the cut-off keeps
/// well-conditioned matrices with uniformly small entries (for example
/// `1e-20 * I`) from being misreported as singular. Capping the scale at one
/// means the test is never stricter than a plain absolute comparison against
/// `H::epsilon()`.
///
/// # Type parameters
///
/// * `A` - element type of the input matrix
/// * `C` - element type of the returned inverse
/// * `H` - floating-point type in which the elimination is carried out
///
/// # Errors
///
/// * [`LinalgError::ShapeError`] if `a` is not square.
/// * [`LinalgError::SingularMatrixError`] if some column has no acceptable
///   pivot.
#[allow(dead_code)]
pub fn mixed_precision_inv<A, C, H>(a: &ArrayView2<A>) -> LinalgResult<Array2<C>>
where
    A: Clone + Debug + ToPrimitive + Copy,
    C: Clone + Zero + NumCast + Debug,
    H: Float + Clone + NumCast + Debug + ToPrimitive + NumAssign + Zero,
{
    let shape = a.shape();
    if shape[0] != shape[1] {
        return Err(LinalgError::ShapeError(format!(
            "Matrix must be square for inversion, got shape {shape:?}"
        )));
    }
    let n = shape[0];
    let a_high = conversions::convert_2d::<A, H>(a);

    // Largest entry magnitude of the input; NaN entries are skipped because
    // every comparison against NaN is false.
    let scale = a_high.iter().fold(H::zero(), |largest, &value| {
        let magnitude = value.abs();
        if magnitude > largest {
            magnitude
        } else {
            largest
        }
    });
    // Scale-aware pivot threshold, capped at the absolute `H::epsilon()` so
    // that large-magnitude inputs are not judged more strictly.
    let pivot_tol = if scale < H::one() {
        H::epsilon() * scale
    } else {
        H::epsilon()
    };

    // Augmented matrix [A | I].
    let mut aug = Array2::<H>::from_shape_fn((n, 2 * n), |(i, j)| {
        if j < n {
            a_high[[i, j]]
        } else if j - n == i {
            H::one()
        } else {
            H::zero()
        }
    });

    for col in 0..n {
        // Partial pivoting: pick the row at or below the diagonal with the
        // largest magnitude in this column.
        let mut pivot_row = col;
        let mut pivot_mag = aug[[col, col]].abs();
        for row in col + 1..n {
            let magnitude = aug[[row, col]].abs();
            if magnitude > pivot_mag {
                pivot_row = row;
                pivot_mag = magnitude;
            }
        }

        // The explicit zero test guards the division below when the threshold
        // itself is zero (all-zero input or underflow of the product).
        if pivot_mag == H::zero() || pivot_mag < pivot_tol {
            return Err(LinalgError::SingularMatrixError(
                "Matrix is singular and cannot be inverted".to_string(),
            ));
        }

        if pivot_row != col {
            for j in 0..2 * n {
                aug.swap([col, j], [pivot_row, j]);
            }
        }

        // Normalise the pivot row so the diagonal entry becomes one.
        let pivot = aug[[col, col]];
        for j in 0..2 * n {
            aug[[col, j]] /= pivot;
        }

        // Clear this column in every other row (above and below the pivot).
        for row in 0..n {
            if row == col {
                continue;
            }
            let factor = aug[[row, col]];
            for k in 0..2 * n {
                aug[[row, k]] = aug[[row, k]] - factor * aug[[col, k]];
            }
        }
    }

    // The right half of the reduced augmented matrix is the inverse.
    let inv_high = Array2::<H>::from_shape_fn((n, n), |(i, j)| aug[[i, n + j]]);
    Ok(conversions::convert_2d::<H, C>(&inv_high.view()))
}
/// Computes the determinant of a square matrix by Gaussian elimination
/// performed in the working type `H`.
///
/// The input is converted to `H` with `conversions::convert_2d` and reduced to
/// upper-triangular form using partial (row) pivoting. The determinant is the
/// product of the pivots, negated once for every row exchange, and is finally
/// cast to `C`. For a `0 x 0` input the elimination loop never runs and the
/// result is one.
///
/// # Singularity threshold
///
/// The function returns zero as soon as a pivot column's largest remaining
/// magnitude is exactly zero, or falls below `H::epsilon() * min(1, s)` where
/// `s` is the largest entry magnitude of the converted input. Scaling the
/// cut-off keeps matrices with uniformly small entries (for example the
/// `1 x 1` matrix `[1e-17]`) from having a representable non-zero determinant
/// reported as zero. Capping the scale at one means the test is never stricter
/// than a plain absolute comparison against `H::epsilon()`.
///
/// # Type parameters
///
/// * `A` - element type of the input matrix
/// * `C` - type of the returned determinant
/// * `H` - floating-point type in which the elimination is carried out
///
/// # Errors
///
/// * [`LinalgError::ShapeError`] if `a` is not square.
/// * [`LinalgError::ComputationError`] if the determinant cannot be cast to
///   `C`.
#[allow(dead_code)]
pub fn mixed_precision_det<A, C, H>(a: &ArrayView2<A>) -> LinalgResult<C>
where
    A: Clone + Debug + ToPrimitive + Copy,
    C: Clone + Zero + NumCast + Debug,
    H: Float + Clone + NumCast + Debug + ToPrimitive + NumAssign + Zero,
{
    let shape = a.shape();
    if shape[0] != shape[1] {
        return Err(LinalgError::ShapeError(format!(
            "Matrix must be square for determinant, got shape {shape:?}"
        )));
    }
    let n = shape[0];
    let mut a_high = conversions::convert_2d::<A, H>(a);

    // Largest entry magnitude of the input; NaN entries are skipped because
    // every comparison against NaN is false.
    let scale = a_high.iter().fold(H::zero(), |largest, &value| {
        let magnitude = value.abs();
        if magnitude > largest {
            magnitude
        } else {
            largest
        }
    });
    // Scale-aware pivot threshold, capped at the absolute `H::epsilon()` so
    // that large-magnitude inputs are not judged more strictly.
    let pivot_tol = if scale < H::one() {
        H::epsilon() * scale
    } else {
        H::epsilon()
    };

    let mut det = H::one();
    for col in 0..n {
        // Partial pivoting: pick the row at or below the diagonal with the
        // largest magnitude in this column.
        let mut pivot_row = col;
        let mut pivot_mag = a_high[[col, col]].abs();
        for row in col + 1..n {
            let magnitude = a_high[[row, col]].abs();
            if magnitude > pivot_mag {
                pivot_row = row;
                pivot_mag = magnitude;
            }
        }

        // The explicit zero test guards the division below when the threshold
        // itself is zero (all-zero input or underflow of the product).
        if pivot_mag == H::zero() || pivot_mag < pivot_tol {
            return Ok(C::zero());
        }

        if pivot_row != col {
            for j in 0..n {
                a_high.swap([col, j], [pivot_row, j]);
            }
            // Each row exchange flips the sign of the determinant.
            det = -det;
        }

        let pivot = a_high[[col, col]];
        det *= pivot;

        // Eliminate below the pivot. Only columns to the right are updated;
        // the entries in the pivot column itself are never read again.
        for row in col + 1..n {
            let factor = a_high[[row, col]] / pivot;
            for k in col + 1..n {
                a_high[[row, k]] = a_high[[row, k]] - factor * a_high[[col, k]];
            }
        }
    }

    C::from(det).ok_or_else(|| {
        LinalgError::ComputationError("Failed to convert determinant to output type".to_string())
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    /// Asserts that every entry of a 2x2 `actual` matrix matches `expected`
    /// within `tolerance`.
    fn assert_matrix_close(actual: &Array2<f32>, expected: &[[f32; 2]; 2], tolerance: f32) {
        for (i, row) in expected.iter().enumerate() {
            for (j, &want) in row.iter().enumerate() {
                assert_relative_eq!(actual[[i, j]], want, epsilon = tolerance);
            }
        }
    }

    /// Matrix-vector product of a 2x2 matrix with a constant vector.
    #[test]
    fn test_mixed_precision_matvec() {
        let matrix = array![[1.0f32, 2.0f32], [3.0f32, 4.0f32]];
        let vector = array![0.5f32, 0.5f32];

        let product = mixed_precision_matvec::<f32, f32, f32, f64>(&matrix.view(), &vector.view())
            .expect("Operation failed");

        assert_eq!(product.len(), 2);
        for (index, want) in [1.5f32, 3.5f32].into_iter().enumerate() {
            assert_relative_eq!(product[index], want, epsilon = 1e-6);
        }
    }

    /// Product of two 2x2 matrices against hand-computed entries.
    #[test]
    fn test_mixed_precision_matmul() {
        let lhs = array![[1.0f32, 2.0f32], [3.0f32, 4.0f32]];
        let rhs = array![[5.0f32, 6.0f32], [7.0f32, 8.0f32]];

        let product = mixed_precision_matmul::<f32, f32, f32, f64>(&lhs.view(), &rhs.view())
            .expect("Operation failed");

        assert_eq!(product.shape(), &[2, 2]);
        assert_matrix_close(&product, &[[19.0, 22.0], [43.0, 50.0]], 1e-5);
    }

    /// Dot product of two length-3 vectors.
    #[test]
    fn test_mixed_precision_dot() {
        let lhs = array![1.0f32, 2.0f32, 3.0f32];
        let rhs = array![4.0f32, 5.0f32, 6.0f32];

        let dot = mixed_precision_dot::<f32, f32, f32, f64>(&lhs.view(), &rhs.view())
            .expect("Operation failed");

        assert_relative_eq!(dot, 32.0f32, epsilon = 1e-6);
    }

    /// Inverse of a 2x2 matrix: checks the entries and that `A * A^-1` is the
    /// identity.
    #[test]
    fn test_mixed_precision_inv() {
        let matrix = array![[1.0f32, 2.0f32], [3.0f32, 4.0f32]];

        let inverse =
            mixed_precision_inv::<f32, f32, f64>(&matrix.view()).expect("Operation failed");

        assert_eq!(inverse.shape(), &[2, 2]);
        assert_matrix_close(&inverse, &[[-2.0, 1.0], [1.5, -0.5]], 1e-5);

        let round_trip = matrix.dot(&inverse);
        assert_matrix_close(&round_trip, &[[1.0, 0.0], [0.0, 1.0]], 1e-4);
    }

    /// Determinants of a general, an identity and a singular 2x2 matrix.
    #[test]
    fn test_mixed_precision_det() {
        let cases = [
            (array![[1.0f32, 2.0f32], [3.0f32, 4.0f32]], -2.0f32),
            (array![[1.0f32, 0.0f32], [0.0f32, 1.0f32]], 1.0f32),
            (array![[1.0f32, 2.0f32], [2.0f32, 4.0f32]], 0.0f32),
        ];

        for (matrix, want) in &cases {
            let det =
                mixed_precision_det::<f32, f32, f64>(&matrix.view()).expect("Operation failed");
            assert_relative_eq!(det, *want, epsilon = 1e-5);
        }
    }

    /// Smoke test: the re-exported entry points are callable and return
    /// results of the expected shape.
    #[test]
    fn test_module_integration() {
        let matrix = array![[1.0f32, 2.0f32], [3.0f32, 4.0f32]];
        let other = array![[5.0f32, 6.0f32], [7.0f32, 8.0f32]];
        let vector = array![1.0f32, 2.0f32];

        let product = mixed_precision_matmul::<f32, f32, f32, f64>(&matrix.view(), &other.view())
            .expect("Operation failed");
        assert_eq!(product.shape(), &[2, 2]);

        let image = mixed_precision_matvec::<f32, f32, f32, f64>(&matrix.view(), &vector.view())
            .expect("Operation failed");
        assert_eq!(image.len(), 2);

        let solution = mixed_precision_solve::<f32, f32, f32, f64>(&matrix.view(), &vector.view())
            .expect("Operation failed");
        assert_eq!(solution.len(), 2);

        let (q, r) =
            mixed_precision_qr::<f32, f32, f64>(&matrix.view()).expect("Operation failed");
        assert_eq!(q.shape(), &[2, 2]);
        assert_eq!(r.shape(), &[2, 2]);
    }
}