use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
use scirs2_core::numeric::{Float, NumAssign, NumCast, ToPrimitive, Zero};
use std::fmt::Debug;
use super::conversions::{convert, convert_2d};
use crate::error::{LinalgError, LinalgResult};
#[allow(dead_code)]
/// Matrix-vector product `a * x` with all accumulation performed in the
/// higher-precision type `H`, demoting to `C` only for the final result.
///
/// Inputs of types `A`/`B` are promoted via the module's conversion helpers;
/// elements that cannot be represented in `C` fall back to `C::zero()`.
///
/// # Errors
///
/// Returns `LinalgError::ShapeError` when the matrix column count does not
/// match the vector length.
pub fn mixed_precision_matvec_f32<A, B, C, H>(
    a: &ArrayView2<A>,
    x: &ArrayView1<B>,
) -> LinalgResult<Array1<C>>
where
    A: Clone + Debug + ToPrimitive + Copy,
    B: Clone + Debug + ToPrimitive + Copy,
    C: Clone + Zero + NumCast + Debug,
    H: Float + Clone + NumCast + Debug + ToPrimitive,
{
    let ashape = a.shape();
    if ashape[1] != x.len() {
        return Err(LinalgError::ShapeError(format!(
            "Matrix columns ({}) must match vector length ({})",
            ashape[1],
            x.len()
        )));
    }

    // Promote both operands to H before any arithmetic.
    let a_high = convert_2d::<A, H>(a);
    let x_high = convert::<B, H>(x);

    // Each output element is the dot product of one matrix row with x.
    // The fold accumulates left-to-right in H, then narrows to C.
    let result = a_high
        .axis_iter(Axis(0))
        .map(|row| {
            let total = row
                .iter()
                .zip(x_high.iter())
                .fold(H::zero(), |acc, (&m, &v)| acc + m * v);
            C::from(total).unwrap_or_else(C::zero)
        })
        .collect::<Array1<C>>();
    Ok(result)
}
#[allow(dead_code)]
/// Matrix product `a * b` accumulated in the higher-precision type `H`
/// and demoted to `C` only once the full product is computed.
///
/// Small problems (all dimensions <= 32) use a plain triple loop; larger
/// problems use a cache-friendly blocked loop with 32x32 tiles. Elements
/// that cannot be represented in `C` fall back to `C::zero()`.
///
/// # Errors
///
/// Returns `LinalgError::ShapeError` when the inner dimensions disagree
/// (`a` columns != `b` rows).
pub fn mixed_precision_matmul_f32_basic<A, B, C, H>(
    a: &ArrayView2<A>,
    b: &ArrayView2<B>,
) -> LinalgResult<Array2<C>>
where
    A: Clone + Debug + ToPrimitive + Copy,
    B: Clone + Debug + ToPrimitive + Copy,
    C: Clone + Zero + NumCast + Debug,
    H: Float + Clone + NumCast + Debug + ToPrimitive + NumAssign + Zero,
{
    let ashape = a.shape();
    let bshape = b.shape();
    if ashape[1] != bshape[0] {
        return Err(LinalgError::ShapeError(format!(
            "Matrix dimensions incompatible for multiplication: {}x{} and {}x{}",
            ashape[0], ashape[1], bshape[0], bshape[1]
        )));
    }
    const BLOCK_SIZE: usize = 32;
    let m = ashape[0];
    let n = bshape[1];
    let k = ashape[1];
    let a_high = convert_2d::<A, H>(a);
    let b_high = convert_2d::<B, H>(b);
    // Single zero-initialized accumulator shared by both paths; the blocked
    // path previously re-zeroed each tile even though Array2::zeros already
    // guarantees a zeroed start, so that redundant pass is removed.
    let mut c_high = Array2::<H>::zeros((m, n));
    if m <= BLOCK_SIZE && n <= BLOCK_SIZE && k <= BLOCK_SIZE {
        // Naive path: small enough that blocking buys nothing.
        for i in 0..m {
            for j in 0..n {
                let mut sum = H::zero();
                for l in 0..k {
                    sum += a_high[[i, l]] * b_high[[l, j]];
                }
                c_high[[i, j]] = sum;
            }
        }
    } else {
        // Blocked path: iterate over BLOCK_SIZE x BLOCK_SIZE tiles to improve
        // cache locality. Accumulation order within each output element is
        // the same k_start..k_end left-to-right order as before.
        let block_m = m.div_ceil(BLOCK_SIZE);
        let block_n = n.div_ceil(BLOCK_SIZE);
        let block_k = k.div_ceil(BLOCK_SIZE);
        for bi in 0..block_m {
            let i_start = bi * BLOCK_SIZE;
            let i_end = std::cmp::min(i_start + BLOCK_SIZE, m);
            for bj in 0..block_n {
                let j_start = bj * BLOCK_SIZE;
                let j_end = std::cmp::min(j_start + BLOCK_SIZE, n);
                for bk in 0..block_k {
                    let k_start = bk * BLOCK_SIZE;
                    let k_end = std::cmp::min(k_start + BLOCK_SIZE, k);
                    for i in i_start..i_end {
                        for j in j_start..j_end {
                            let mut sum = H::zero();
                            for l in k_start..k_end {
                                sum += a_high[[i, l]] * b_high[[l, j]];
                            }
                            c_high[[i, j]] += sum;
                        }
                    }
                }
            }
        }
    }
    // One shared H -> C narrowing pass (previously duplicated in both paths).
    Ok(c_high.mapv(|val| C::from(val).unwrap_or_else(C::zero)))
}
#[allow(dead_code)]
/// Dot product of two vectors, accumulated in the higher-precision type `H`.
///
/// Very short vectors (length <= 4) use a plain left-to-right sum; longer
/// vectors use Kahan compensated summation to limit rounding error.
///
/// # Errors
///
/// Returns `LinalgError::ShapeError` on length mismatch, and
/// `LinalgError::ComputationError` when the result cannot be narrowed to `C`.
pub fn mixed_precision_dot_f32<A, B, C, H>(a: &ArrayView1<A>, b: &ArrayView1<B>) -> LinalgResult<C>
where
    A: Clone + Debug + ToPrimitive + Copy,
    B: Clone + Debug + ToPrimitive + Copy,
    C: Clone + Zero + NumCast + Debug,
    H: Float + Clone + NumCast + Debug + ToPrimitive,
{
    if a.len() != b.len() {
        return Err(LinalgError::ShapeError(format!(
            "Vector dimensions must match for dot product: {} vs {}",
            a.len(),
            b.len()
        )));
    }
    // Promote both operands to H once, for either summation strategy.
    let a_high = convert::<A, H>(a);
    let b_high = convert::<B, H>(b);
    let total = if a.len() <= 4 {
        // Short path: plain left-to-right accumulation.
        a_high
            .iter()
            .zip(b_high.iter())
            .fold(H::zero(), |acc, (&x, &y)| acc + x * y)
    } else {
        // Kahan compensated summation: `compensation` carries the low-order
        // bits lost by each addition so they re-enter the next term.
        let mut running = H::zero();
        let mut compensation = H::zero();
        for (&x, &y) in a_high.iter().zip(b_high.iter()) {
            let term = x * y - compensation;
            let candidate = running + term;
            compensation = (candidate - running) - term;
            running = candidate;
        }
        running
    };
    C::from(total).ok_or_else(|| {
        LinalgError::ComputationError("Failed to convert dot product result".to_string())
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    // [1 2; 3 4] * [0.5, 0.5] = [1.5, 3.5]
    #[test]
    fn test_mixed_precision_matvec_f32() {
        let matrix = array![[1.0f32, 2.0f32], [3.0f32, 4.0f32]];
        let vector = array![0.5f32, 0.5f32];
        let product = mixed_precision_matvec_f32::<f32, f32, f32, f64>(&matrix.view(), &vector.view())
            .expect("Operation failed");
        assert_eq!(product.len(), 2);
        assert_relative_eq!(product[0], 1.5f32, epsilon = 1e-6);
        assert_relative_eq!(product[1], 3.5f32, epsilon = 1e-6);
    }

    // 2x2 operands exercise the small-matrix (naive) path.
    #[test]
    fn test_mixed_precision_matmul_f32_basic_small() {
        let lhs = array![[1.0f32, 2.0f32], [3.0f32, 4.0f32]];
        let rhs = array![[5.0f32, 6.0f32], [7.0f32, 8.0f32]];
        let product = mixed_precision_matmul_f32_basic::<f32, f32, f32, f64>(&lhs.view(), &rhs.view())
            .expect("Operation failed");
        assert_eq!(product.shape(), &[2, 2]);
        assert_relative_eq!(product[[0, 0]], 19.0f32, epsilon = 1e-5);
        assert_relative_eq!(product[[0, 1]], 22.0f32, epsilon = 1e-5);
        assert_relative_eq!(product[[1, 0]], 43.0f32, epsilon = 1e-5);
        assert_relative_eq!(product[[1, 1]], 50.0f32, epsilon = 1e-5);
    }

    // 64x64 all-ones matrices exercise the blocked path; every entry is 64.
    #[test]
    fn test_mixed_precision_matmul_f32_basic_large() {
        let size = 64;
        let lhs = Array2::<f32>::ones((size, size));
        let rhs = Array2::<f32>::ones((size, size));
        let product = mixed_precision_matmul_f32_basic::<f32, f32, f32, f64>(&lhs.view(), &rhs.view())
            .expect("Operation failed");
        assert_eq!(product.shape(), &[size, size]);
        for row in 0..size {
            for col in 0..size {
                assert_relative_eq!(product[[row, col]], size as f32, epsilon = 1e-4);
            }
        }
    }

    // 1*4 + 2*5 + 3*6 = 32
    #[test]
    fn test_mixed_precision_dot_f32() {
        let lhs = array![1.0f32, 2.0f32, 3.0f32];
        let rhs = array![4.0f32, 5.0f32, 6.0f32];
        let dot = mixed_precision_dot_f32::<f32, f32, f32, f64>(&lhs.view(), &rhs.view())
            .expect("Operation failed");
        assert_relative_eq!(dot, 32.0f32, epsilon = 1e-6);
    }

    // 1*3 + 2*4 = 11, via the short-vector path.
    #[test]
    fn test_mixed_precision_dot_f32_short() {
        let lhs = array![1.0f32, 2.0f32];
        let rhs = array![3.0f32, 4.0f32];
        let dot = mixed_precision_dot_f32::<f32, f32, f32, f64>(&lhs.view(), &rhs.view())
            .expect("Operation failed");
        assert_relative_eq!(dot, 11.0f32, epsilon = 1e-6);
    }

    // Mixed magnitudes: f64 accumulation keeps the result close to exact 32.
    #[test]
    fn test_mixed_precision_dot_f32_precision() {
        let lhs = array![1e-6f32, 2e6f32, 3e-6f32];
        let rhs = array![4e6f32, 5e-6f32, 6e6f32];
        let dot = mixed_precision_dot_f32::<f32, f32, f32, f64>(&lhs.view(), &rhs.view())
            .expect("Operation failed");
        assert_relative_eq!(dot, 32.0f32, epsilon = 1e-3);
    }

    // All three entry points reject mismatched shapes with Err.
    #[test]
    fn test_dimension_mismatch_errors() {
        let matrix = array![[1.0f32, 2.0f32], [3.0f32, 4.0f32]];
        let vector = array![1.0f32, 2.0f32, 3.0f32];
        let matvec = mixed_precision_matvec_f32::<f32, f32, f32, f64>(&matrix.view(), &vector.view());
        assert!(matvec.is_err());

        let lhs = array![[1.0f32, 2.0f32], [3.0f32, 4.0f32]];
        let rhs = array![[1.0f32, 2.0f32, 3.0f32]];
        let matmul = mixed_precision_matmul_f32_basic::<f32, f32, f32, f64>(&lhs.view(), &rhs.view());
        assert!(matmul.is_err());

        let short = array![1.0f32, 2.0f32];
        let long = array![1.0f32, 2.0f32, 3.0f32];
        let dot = mixed_precision_dot_f32::<f32, f32, f32, f64>(&short.view(), &long.view());
        assert!(dot.is_err());
    }
}