use oxicuda_blas::{level3::gemm_api::gemm, Layout, MatrixDesc, MatrixDescMut, Transpose};
use oxicuda_memory::DeviceBuffer;
use crate::context::CudaContext;
use crate::error::CudaDispatchError;
/// Computes the single-precision matrix product `C = A * B` using the BLAS
/// handle obtained from `ctx.dnn`.
///
/// `a_data` is interpreted as an `m x k` row-major matrix and `b_data` as a
/// `k x n` row-major matrix. Both are copied into freshly allocated
/// [`DeviceBuffer`]s, multiplied with `gemm` (no transposition, `alpha = 1.0`,
/// `beta = 0.0`), and the `m x n` row-major result is copied back into a new
/// host `Vec<f32>` after the stream returned by `ctx.dnn.stream()` has been
/// synchronized.
///
/// # Errors
///
/// Returns [`CudaDispatchError::Blas`] when
/// - any of `m`, `k`, `n` does not fit in `u32` (the matrix descriptors take
///   `u32` dimensions),
/// - an element count (`m * k`, `k * n`, `m * n`) overflows `usize`,
/// - `a_data.len() != m * k` or `b_data.len() != k * n`,
/// - building a matrix descriptor or running `gemm` fails.
///
/// Returns [`CudaDispatchError::Driver`] when stream synchronization fails.
/// Errors from device allocation and host/device copies are propagated via
/// `?` using whatever conversion into [`CudaDispatchError`] the crate defines.
pub fn cuda_matmul(
    ctx: &CudaContext,
    a_data: &[f32],
    b_data: &[f32],
    m: usize,
    k: usize,
    n: usize,
) -> Result<Vec<f32>, CudaDispatchError> {
    // The descriptors take `u32` dimensions; a plain `as` cast would silently
    // truncate oversized values and describe a different matrix.
    let dim_u32 = |name: &str, value: usize| -> Result<u32, CudaDispatchError> {
        <u32 as std::convert::TryFrom<usize>>::try_from(value).map_err(|_| {
            CudaDispatchError::Blas(format!(
                "cuda_matmul: dimension {} = {} does not fit in u32",
                name, value
            ))
        })
    };
    let m32 = dim_u32("m", m)?;
    let k32 = dim_u32("k", k)?;
    let n32 = dim_u32("n", n)?;

    // Element counts are computed with overflow checking so that a wrapped
    // product can never be used as an allocation size.
    let elem_count = |what: &str, rows: usize, cols: usize| -> Result<usize, CudaDispatchError> {
        rows.checked_mul(cols).ok_or_else(|| {
            CudaDispatchError::Blas(format!(
                "cuda_matmul: element count of {} ({} x {}) overflows usize",
                what, rows, cols
            ))
        })
    };
    let a_len = elem_count("A", m, k)?;
    let b_len = elem_count("B", k, n)?;
    let c_len = elem_count("C", m, n)?;

    // Reject mismatched host inputs before touching the device.
    if a_data.len() != a_len {
        return Err(CudaDispatchError::Blas(format!(
            "cuda_matmul: a_data has {} elements, expected m * k = {}",
            a_data.len(),
            a_len
        )));
    }
    if b_data.len() != b_len {
        return Err(CudaDispatchError::Blas(format!(
            "cuda_matmul: b_data has {} elements, expected k * n = {}",
            b_data.len(),
            b_len
        )));
    }

    // Stage the operands on the device; the output buffer starts zeroed.
    let mut d_a: DeviceBuffer<f32> = DeviceBuffer::alloc(a_len)?;
    d_a.copy_from_host(a_data)?;
    let mut d_b: DeviceBuffer<f32> = DeviceBuffer::alloc(b_len)?;
    d_b.copy_from_host(b_data)?;
    let mut d_c: DeviceBuffer<f32> = DeviceBuffer::zeroed(c_len)?;

    let desc_a = MatrixDesc::<f32>::from_buffer(&d_a, m32, k32, Layout::RowMajor)
        .map_err(|e| CudaDispatchError::Blas(e.to_string()))?;
    let desc_b = MatrixDesc::<f32>::from_buffer(&d_b, k32, n32, Layout::RowMajor)
        .map_err(|e| CudaDispatchError::Blas(e.to_string()))?;
    let mut desc_c = MatrixDescMut::<f32>::from_buffer(&mut d_c, m32, n32, Layout::RowMajor)
        .map_err(|e| CudaDispatchError::Blas(e.to_string()))?;

    // C = 1.0 * A * B + 0.0 * C, neither operand transposed.
    let blas_handle = ctx.dnn.blas();
    gemm(
        blas_handle,
        Transpose::NoTrans,
        Transpose::NoTrans,
        1.0_f32,
        &desc_a,
        &desc_b,
        0.0_f32,
        &mut desc_c,
    )
    .map_err(|e| CudaDispatchError::Blas(e.to_string()))?;

    // Wait for the stream to finish before reading the result back.
    ctx.dnn
        .stream()
        .synchronize()
        .map_err(CudaDispatchError::Driver)?;

    let mut out = vec![0.0_f32; c_len];
    d_c.copy_to_host(&mut out)?;
    Ok(out)
}