use scirs2_core::gpu::{GpuBackend, GpuContext, GpuError};
use scirs2_core::ndarray::{Array2, ArrayView2, ScalarOperand};
use scirs2_core::numeric::{Float, NumAssign};
use std::iter::Sum;
use crate::error::{LinalgError, LinalgResult};
use super::types::BatchMatmulResult;
const GPU_MATMUL_THRESHOLD: usize = 64;
pub fn gpu_matmul<F>(
ctx: Option<&GpuContext>,
a: &ArrayView2<F>,
b: &ArrayView2<F>,
) -> LinalgResult<Array2<F>>
where
F: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
{
let (m, k_a) = a.dim();
let (k_b, n) = b.dim();
if k_a != k_b {
return Err(LinalgError::DimensionError(format!(
"Matrix multiplication dimension mismatch: A is {}x{} but B is {}x{}",
m, k_a, k_b, n
)));
}
let k = k_a;
let use_gpu = ctx.is_some()
&& m >= GPU_MATMUL_THRESHOLD
&& n >= GPU_MATMUL_THRESHOLD
&& k >= GPU_MATMUL_THRESHOLD;
if use_gpu {
if let Some(gpu_ctx) = ctx {
match gpu_matmul_impl(gpu_ctx, a, b, m, k, n) {
Ok(result) => return Ok(result),
Err(_) => {
}
}
}
}
cpu_matmul(a, b)
}
pub fn gpu_batched_matmul<F>(
ctx: Option<&GpuContext>,
a_batch: &[Array2<F>],
b_batch: &[Array2<F>],
) -> LinalgResult<BatchMatmulResult<F>>
where
F: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
{
if a_batch.len() != b_batch.len() {
return Err(LinalgError::DimensionError(format!(
"Batch size mismatch: {} left matrices vs {} right matrices",
a_batch.len(),
b_batch.len()
)));
}
if a_batch.is_empty() {
return Ok(BatchMatmulResult {
results: Vec::new(),
});
}
for (i, (a, b)) in a_batch.iter().zip(b_batch.iter()).enumerate() {
let (_, k_a) = a.dim();
let (k_b, _) = b.dim();
if k_a != k_b {
return Err(LinalgError::DimensionError(format!(
"Batch element {}: dimension mismatch A({}x{}) * B({}x{})",
i,
a.nrows(),
k_a,
k_b,
b.ncols()
)));
}
}
if let Some(gpu_ctx) = ctx {
match gpu_batched_matmul_impl(gpu_ctx, a_batch, b_batch) {
Ok(results) => return Ok(results),
Err(_) => {
}
}
}
let mut results = Vec::with_capacity(a_batch.len());
for (a, b) in a_batch.iter().zip(b_batch.iter()) {
results.push(cpu_matmul(&a.view(), &b.view())?);
}
Ok(BatchMatmulResult { results })
}
fn gpu_matmul_impl<F>(
ctx: &GpuContext,
a: &ArrayView2<F>,
b: &ArrayView2<F>,
m: usize,
k: usize,
n: usize,
) -> Result<Array2<F>, LinalgError>
where
F: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
{
let a_flat: Vec<f64> = a.iter().map(|v| v.to_f64().unwrap_or(0.0)).collect();
let b_flat: Vec<f64> = b.iter().map(|v| v.to_f64().unwrap_or(0.0)).collect();
let gpu_a = ctx.create_buffer_from_slice(&a_flat);
let gpu_b = ctx.create_buffer_from_slice(&b_flat);
let gpu_c = ctx
.gemm(&gpu_a, &gpu_b, m, k, n)
.map_err(|e| LinalgError::ComputationError(format!("GPU GEMM failed: {}", e)))?;
let c_flat = gpu_c.to_vec();
let c_data: Vec<F> = c_flat
.iter()
.map(|&v| F::from(v).unwrap_or_else(F::zero))
.collect();
Array2::from_shape_vec((m, n), c_data)
.map_err(|e| LinalgError::ShapeError(format!("Failed to reshape GPU result: {}", e)))
}
fn gpu_batched_matmul_impl<F>(
ctx: &GpuContext,
a_batch: &[Array2<F>],
b_batch: &[Array2<F>],
) -> Result<BatchMatmulResult<F>, LinalgError>
where
F: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
{
let mut results = Vec::with_capacity(a_batch.len());
for (a, b) in a_batch.iter().zip(b_batch.iter()) {
let (m, k) = a.dim();
let (_, n) = b.dim();
let a_flat: Vec<f64> = a.iter().map(|v| v.to_f64().unwrap_or(0.0)).collect();
let b_flat: Vec<f64> = b.iter().map(|v| v.to_f64().unwrap_or(0.0)).collect();
let gpu_a = ctx.create_buffer_from_slice(&a_flat);
let gpu_b = ctx.create_buffer_from_slice(&b_flat);
let gpu_c = ctx.gemm(&gpu_a, &gpu_b, m, k, n).map_err(|e| {
LinalgError::ComputationError(format!("GPU batched GEMM failed: {}", e))
})?;
let c_flat = gpu_c.to_vec();
let c_data: Vec<F> = c_flat
.iter()
.map(|&v| F::from(v).unwrap_or_else(F::zero))
.collect();
let result = Array2::from_shape_vec((m, n), c_data).map_err(|e| {
LinalgError::ShapeError(format!("Failed to reshape batch result: {}", e))
})?;
results.push(result);
}
Ok(BatchMatmulResult { results })
}
fn cpu_matmul<F>(a: &ArrayView2<F>, b: &ArrayView2<F>) -> LinalgResult<Array2<F>>
where
F: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
{
let (m, k) = a.dim();
let (_, n) = b.dim();
let mut c = Array2::<F>::zeros((m, n));
for i in 0..m {
for j in 0..n {
let mut sum = F::zero();
for p in 0..k {
sum += a[[i, p]] * b[[p, j]];
}
c[[i, j]] = sum;
}
}
Ok(c)
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::array;
#[test]
fn test_cpu_matmul_basic() {
let a = array![[1.0_f64, 2.0], [3.0, 4.0]];
let b = array![[5.0_f64, 6.0], [7.0, 8.0]];
let c = cpu_matmul(&a.view(), &b.view()).expect("matmul failed");
assert!((c[[0, 0]] - 19.0).abs() < 1e-10);
assert!((c[[0, 1]] - 22.0).abs() < 1e-10);
assert!((c[[1, 0]] - 43.0).abs() < 1e-10);
assert!((c[[1, 1]] - 50.0).abs() < 1e-10);
}
#[test]
fn test_gpu_matmul_dimension_mismatch() {
let a = array![[1.0_f64, 2.0, 3.0]];
let b = array![[1.0_f64], [2.0]];
let result = gpu_matmul(None, &a.view(), &b.view());
assert!(result.is_err());
}
#[test]
fn test_gpu_matmul_cpu_fallback() {
let a = array![[1.0_f64, 0.0], [0.0, 1.0]];
let b = array![[5.0_f64, 6.0], [7.0, 8.0]];
let c = gpu_matmul(None, &a.view(), &b.view()).expect("fallback failed");
assert!((c[[0, 0]] - 5.0).abs() < 1e-10);
assert!((c[[1, 1]] - 8.0).abs() < 1e-10);
}
#[test]
fn test_gpu_matmul_rectangular() {
let a = array![[1.0_f64, 2.0, 3.0], [4.0, 5.0, 6.0]]; let b = array![[1.0_f64, 2.0], [3.0, 4.0], [5.0, 6.0]]; let c = gpu_matmul(None, &a.view(), &b.view()).expect("rectangular matmul failed");
assert_eq!(c.dim(), (2, 2));
assert!((c[[0, 0]] - 22.0).abs() < 1e-10);
assert!((c[[0, 1]] - 28.0).abs() < 1e-10);
assert!((c[[1, 0]] - 49.0).abs() < 1e-10);
assert!((c[[1, 1]] - 64.0).abs() < 1e-10);
}
#[test]
fn test_gpu_batched_matmul_empty() {
let result: LinalgResult<BatchMatmulResult<f64>> = gpu_batched_matmul(None, &[], &[]);
let batch = result.expect("empty batch failed");
assert!(batch.results.is_empty());
}
#[test]
fn test_gpu_batched_matmul_size_mismatch() {
let a = vec![array![[1.0_f64]]];
let b: Vec<Array2<f64>> = vec![];
let result = gpu_batched_matmul(None, &a, &b);
assert!(result.is_err());
}
#[test]
fn test_gpu_batched_matmul_cpu_fallback() {
let a1 = array![[1.0_f64, 2.0], [3.0, 4.0]];
let b1 = array![[5.0_f64, 6.0], [7.0, 8.0]];
let a2 = array![[1.0_f64, 0.0], [0.0, 1.0]];
let b2 = array![[9.0_f64, 10.0], [11.0, 12.0]];
let batch = gpu_batched_matmul(None, &[a1, a2], &[b1, b2]).expect("batch failed");
assert_eq!(batch.results.len(), 2);
assert!((batch.results[0][[0, 0]] - 19.0).abs() < 1e-10);
assert!((batch.results[1][[0, 0]] - 9.0).abs() < 1e-10);
assert!((batch.results[1][[1, 1]] - 12.0).abs() < 1e-10);
}
#[test]
fn test_gpu_matmul_f32() {
let a = array![[1.0_f32, 2.0], [3.0, 4.0]];
let b = array![[5.0_f32, 6.0], [7.0, 8.0]];
let c = gpu_matmul(None, &a.view(), &b.view()).expect("f32 matmul failed");
assert!((c[[0, 0]] - 19.0).abs() < 1e-4);
assert!((c[[1, 1]] - 50.0).abs() < 1e-4);
}
#[test]
fn test_gpu_matmul_with_cpu_context() {
let ctx = GpuContext::new(GpuBackend::Cpu);
match ctx {
Ok(gpu_ctx) => {
let a = array![[1.0_f64, 2.0], [3.0, 4.0]];
let b = array![[5.0_f64, 6.0], [7.0, 8.0]];
let c = gpu_matmul(Some(&gpu_ctx), &a.view(), &b.view())
.expect("matmul with ctx failed");
assert!((c[[0, 0]] - 19.0).abs() < 1e-10);
}
Err(_) => {
}
}
}
}