use scirs2_core::ndarray::{Array2, ArrayView2};
use scirs2_core::numeric::{Float, NumAssign, NumCast, ToPrimitive, Zero};
use std::fmt::Debug;
use super::conversions::convert_2d;
use crate::error::{LinalgError, LinalgResult};
/// Mixed-precision matrix product `a * b`, parallel build.
///
/// Both operands are promoted element-wise to the high-precision type `H`,
/// the product is accumulated in `H`, and the result is cast down to the
/// output type `C` (elements whose cast fails become `C::zero()`).
///
/// Strategy by problem size:
/// * all dims <= 32  — plain serial triple loop (parallel overhead would
///   dominate at this size);
/// * all dims <= 512 — 32x32 cache-blocked multiply; only the final
///   `H -> C` conversion runs in parallel;
/// * otherwise       — row-parallel multiply via `Zip::par_for_each`.
///
/// # Errors
/// Returns `LinalgError::ShapeError` when `a`'s column count differs from
/// `b`'s row count.
#[cfg(feature = "parallel")]
#[allow(dead_code)]
pub fn mixed_precision_matmul_f64_parallel<A, B, C, H>(
    a: &ArrayView2<A>,
    b: &ArrayView2<B>,
) -> LinalgResult<Array2<C>>
where
    A: Clone + Debug + ToPrimitive + Copy + Sync,
    B: Clone + Debug + ToPrimitive + Copy + Sync,
    C: Clone + Zero + NumCast + Debug + Send,
    H: Float + Clone + NumCast + Debug + ToPrimitive + NumAssign + Zero + Send + Sync,
{
    use scirs2_core::ndarray::Zip;
    let ashape = a.shape();
    let bshape = b.shape();
    // (m x k) * (k x n): the inner dimensions must agree.
    if ashape[1] != bshape[0] {
        return Err(LinalgError::ShapeError(format!(
            "Matrix dimensions incompatible for multiplication: {}x{} and {}x{}",
            ashape[0], ashape[1], bshape[0], bshape[1]
        )));
    }
    let m = ashape[0]; // output rows
    let n = bshape[1]; // output columns
    let k = ashape[1]; // shared inner dimension
    // Promote both operands to the high-precision accumulator type up front.
    let a_high = convert_2d::<A, H>(a);
    let b_high = convert_2d::<B, H>(b);
    // Small case: naive serial multiply.
    if m <= 32 && n <= 32 && k <= 32 {
        let mut c_high = Array2::<H>::zeros((m, n));
        for i in 0..m {
            for j in 0..n {
                let mut sum = H::zero();
                for l in 0..k {
                    sum += a_high[[i, l]] * b_high[[l, j]];
                }
                c_high[[i, j]] = sum;
            }
        }
        // Cast back to the output type; a failed cast falls back to zero.
        let mut c = Array2::<C>::zeros((m, n));
        for i in 0..m {
            for j in 0..n {
                c[[i, j]] = C::from(c_high[[i, j]]).unwrap_or_else(|| C::zero());
            }
        }
        return Ok(c);
    }
    // Medium case: cache-blocked multiply.
    // NOTE(review): the blocked multiply itself is serial here; only the
    // H -> C conversion below runs in parallel — presumably intentional for
    // medium sizes, but worth confirming.
    if m <= 512 && n <= 512 && k <= 512 {
        const BLOCK_SIZE: usize = 32;
        let mut c_high = Array2::<H>::zeros((m, n));
        let block_m = m.div_ceil(BLOCK_SIZE);
        let block_n = n.div_ceil(BLOCK_SIZE);
        let block_k = k.div_ceil(BLOCK_SIZE);
        for bi in 0..block_m {
            let i_start = bi * BLOCK_SIZE;
            let i_end = std::cmp::min(i_start + BLOCK_SIZE, m);
            for bj in 0..block_n {
                let j_start = bj * BLOCK_SIZE;
                let j_end = std::cmp::min(j_start + BLOCK_SIZE, n);
                for i in i_start..i_end {
                    for j in j_start..j_end {
                        c_high[[i, j]] = H::zero();
                        // Accumulate this element tile-by-tile along k,
                        // summing each tile into a local partial first.
                        for bk in 0..block_k {
                            let k_start = bk * BLOCK_SIZE;
                            let k_end = std::cmp::min(k_start + BLOCK_SIZE, k);
                            let mut sum = H::zero();
                            for l in k_start..k_end {
                                sum += a_high[[i, l]] * b_high[[l, j]];
                            }
                            c_high[[i, j]] += sum;
                        }
                    }
                }
            }
        }
        // Parallel element-wise cast H -> C; failed casts fall back to zero.
        let mut c = Array2::<C>::zeros((m, n));
        Zip::from(&mut c)
            .and(&c_high)
            .par_for_each(|c_val, &h_val| {
                *c_val = C::from(h_val).unwrap_or_else(|| C::zero());
            });
        return Ok(c);
    }
    // Large case: each parallel task computes one full output row of the
    // high-precision product.
    let mut c_high = Array2::<H>::zeros((m, n));
    Zip::from(c_high.rows_mut())
        .and(a_high.rows())
        .par_for_each(|mut c_row, a_row| {
            for (j, c_val) in c_row.iter_mut().enumerate() {
                let mut sum = H::zero();
                for (l, &a_val) in a_row.iter().enumerate() {
                    sum += a_val * b_high[[l, j]];
                }
                *c_val = sum;
            }
        });
    // Parallel element-wise cast H -> C; failed casts fall back to zero.
    let mut c = Array2::<C>::zeros((m, n));
    Zip::from(&mut c)
        .and(&c_high)
        .par_for_each(|c_val, &h_val| {
            *c_val = C::from(h_val).unwrap_or_else(|| C::zero());
        });
    Ok(c)
}
/// Mixed-precision matrix product `a * b`, serial build.
///
/// Both operands are promoted element-wise to the high-precision type `H`,
/// the product is accumulated in `H`, and the result is cast down to the
/// output type `C` (elements whose cast fails become `C::zero()`).
///
/// Problems with every dimension <= 32 use a naive triple loop; everything
/// else goes through a 32x32 cache-blocked multiply.
///
/// # Errors
/// Returns `LinalgError::ShapeError` when `a`'s column count differs from
/// `b`'s row count.
#[cfg(not(feature = "parallel"))]
#[allow(dead_code)]
pub fn mixed_precision_matmul_f64_serial<A, B, C, H>(
    a: &ArrayView2<A>,
    b: &ArrayView2<B>,
) -> LinalgResult<Array2<C>>
where
    A: Clone + Debug + ToPrimitive + Copy,
    B: Clone + Debug + ToPrimitive + Copy,
    C: Clone + Zero + NumCast + Debug,
    H: Float + Clone + NumCast + Debug + ToPrimitive + NumAssign + Zero,
{
    let (rows, inner) = (a.shape()[0], a.shape()[1]);
    let (b_rows, cols) = (b.shape()[0], b.shape()[1]);
    // The product is defined only when a's columns match b's rows.
    if inner != b_rows {
        return Err(LinalgError::ShapeError(format!(
            "Matrix dimensions incompatible for multiplication: {}x{} and {}x{}",
            rows, inner, b_rows, cols
        )));
    }
    // Promote both operands to the high-precision type before multiplying.
    let lhs = convert_2d::<A, H>(a);
    let rhs = convert_2d::<B, H>(b);
    // Casts the high-precision result back to the requested output type,
    // substituting zero for any element the cast cannot represent.
    let demote = |product: &Array2<H>| product.mapv(|v| C::from(v).unwrap_or_else(C::zero));
    // Tiny problems: a plain triple loop beats the blocked version.
    if rows <= 32 && cols <= 32 && inner <= 32 {
        let mut product = Array2::<H>::zeros((rows, cols));
        for r in 0..rows {
            for c in 0..cols {
                let mut acc = H::zero();
                for t in 0..inner {
                    acc += lhs[[r, t]] * rhs[[t, c]];
                }
                product[[r, c]] = acc;
            }
        }
        return Ok(demote(&product));
    }
    // Everything else: classic 32x32 cache-blocked multiplication, still
    // accumulating in the high-precision type.
    const TILE: usize = 32;
    let mut product = Array2::<H>::zeros((rows, cols));
    for r0 in (0..rows).step_by(TILE) {
        let r1 = (r0 + TILE).min(rows);
        for c0 in (0..cols).step_by(TILE) {
            let c1 = (c0 + TILE).min(cols);
            // `zeros` already initialised this tile and each output element
            // belongs to exactly one (r0, c0) tile, so no explicit reset is
            // needed before accumulating the per-tile partial sums.
            for t0 in (0..inner).step_by(TILE) {
                let t1 = (t0 + TILE).min(inner);
                for r in r0..r1 {
                    for c in c0..c1 {
                        let mut acc = H::zero();
                        for t in t0..t1 {
                            acc += lhs[[r, t]] * rhs[[t, c]];
                        }
                        product[[r, c]] += acc;
                    }
                }
            }
        }
    }
    Ok(demote(&product))
}
/// Mixed-precision matrix product `a * b`: dispatching front-end that
/// forwards to the parallel implementation when the `parallel` feature is
/// enabled, and to the serial one otherwise.
///
/// Type parameters: `A`/`B` are the input element types, `C` the output
/// element type, and `H` the high-precision accumulator type.
///
/// # Errors
/// Returns `LinalgError::ShapeError` when the inner dimensions of `a` and
/// `b` do not agree.
#[allow(dead_code)]
pub fn mixed_precision_matmul_f64<A, B, C, H>(
    a: &ArrayView2<A>,
    b: &ArrayView2<B>,
) -> LinalgResult<Array2<C>>
where
    A: Clone + Debug + ToPrimitive + Copy + Sync,
    B: Clone + Debug + ToPrimitive + Copy + Sync,
    C: Clone + Zero + NumCast + Debug + Send,
    H: Float + Clone + NumCast + Debug + ToPrimitive + NumAssign + Zero + Send + Sync,
{
    #[cfg(feature = "parallel")]
    {
        mixed_precision_matmul_f64_parallel::<A, B, C, H>(a, b)
    }
    #[cfg(not(feature = "parallel"))]
    {
        mixed_precision_matmul_f64_serial::<A, B, C, H>(a, b)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    // 2x2 product, small enough to exercise the naive path.
    #[test]
    fn test_mixed_precision_matmul_f64_small() {
        let lhs = array![[1.0f32, 2.0f32], [3.0f32, 4.0f32]];
        let rhs = array![[5.0f32, 6.0f32], [7.0f32, 8.0f32]];
        let got = mixed_precision_matmul_f64::<f32, f32, f32, f64>(&lhs.view(), &rhs.view())
            .expect("Operation failed");
        assert_eq!(got.shape(), &[2, 2]);
        let want = [[19.0f32, 22.0f32], [43.0f32, 50.0f32]];
        for (r, row) in want.iter().enumerate() {
            for (c, &expected) in row.iter().enumerate() {
                assert_relative_eq!(got[[r, c]], expected, epsilon = 1e-5);
            }
        }
    }

    // 64x64 all-ones product exercises the blocked path; every element
    // of ones * ones equals the shared dimension.
    #[test]
    fn test_mixed_precision_matmul_f64_medium() {
        let size = 64;
        let lhs = Array2::<f32>::ones((size, size));
        let rhs = Array2::<f32>::ones((size, size));
        let got = mixed_precision_matmul_f64::<f32, f32, f32, f64>(&lhs.view(), &rhs.view())
            .expect("Operation failed");
        assert_eq!(got.shape(), &[size, size]);
        for &v in got.iter() {
            assert_relative_eq!(v, size as f32, epsilon = 1e-4);
        }
    }

    // 128x128 all-ones product; only a 10x10 corner is spot-checked to
    // keep the test fast.
    #[test]
    fn test_mixed_precision_matmul_f64_large() {
        let size = 128;
        let lhs = Array2::<f32>::ones((size, size));
        let rhs = Array2::<f32>::ones((size, size));
        let got = mixed_precision_matmul_f64::<f32, f32, f32, f64>(&lhs.view(), &rhs.view())
            .expect("Operation failed");
        assert_eq!(got.shape(), &[size, size]);
        for (i, j) in (0..10).flat_map(|r| (0..10).map(move |c| (r, c))) {
            assert_relative_eq!(got[[i, j]], size as f32, epsilon = 1e-4);
        }
    }

    // Mixed magnitudes (1e-6 against 1e6) must not overflow or produce
    // NaN when accumulated in f64.
    #[test]
    fn test_mixed_precision_matmul_f64_precision() {
        let lhs = array![[1e-6f32, 2e6f32], [3e-6f32, 4e6f32]];
        let rhs = array![[5e6f32, 6e-6f32], [7e-6f32, 8e6f32]];
        let got = mixed_precision_matmul_f64::<f32, f32, f32, f64>(&lhs.view(), &rhs.view())
            .expect("Operation failed");
        assert_eq!(got.shape(), &[2, 2]);
        for &(r, c) in &[(0, 0), (0, 1), (1, 0), (1, 1)] {
            assert!(got[[r, c]].is_finite());
        }
    }

    // Incompatible inner dimensions (2x2 times 1x3) must yield ShapeError.
    #[test]
    fn test_mixed_precision_matmul_f64_errors() {
        let lhs = array![[1.0f32, 2.0f32], [3.0f32, 4.0f32]];
        let rhs = array![[1.0f32, 2.0f32, 3.0f32]];
        let outcome = mixed_precision_matmul_f64::<f32, f32, f32, f64>(&lhs.view(), &rhs.view());
        assert!(outcome.is_err());
        match outcome {
            Err(LinalgError::ShapeError(_)) => {}
            _ => panic!("Expected ShapeError"),
        }
    }

    // Non-square shapes: (2x3) * (3x2) = 2x2.
    #[test]
    fn test_mixed_precision_matmul_f64_rectangular() {
        let lhs = array![[1.0f32, 2.0f32, 3.0f32], [4.0f32, 5.0f32, 6.0f32]];
        let rhs = array![[7.0f32, 8.0f32], [9.0f32, 10.0f32], [11.0f32, 12.0f32]];
        let got = mixed_precision_matmul_f64::<f32, f32, f32, f64>(&lhs.view(), &rhs.view())
            .expect("Operation failed");
        assert_eq!(got.shape(), &[2, 2]);
        let want = [[58.0f32, 64.0f32], [139.0f32, 154.0f32]];
        for (r, row) in want.iter().enumerate() {
            for (c, &expected) in row.iter().enumerate() {
                assert_relative_eq!(got[[r, c]], expected, epsilon = 1e-5);
            }
        }
    }
}