use super::Context;
use super::Operation;
use crate::ffi::*;
use crate::{Error, API};
impl API {
    /// Runs a single-precision general matrix-matrix multiplication (SGEMM).
    ///
    /// Unwraps `context`, `transa` and `transb` into their raw C
    /// representations and forwards every argument, unchanged and in the same
    /// order, to `cublasSgemm_v2`. Per the cuBLAS documentation that routine
    /// computes `C = alpha * op(A) * op(B) + beta * C` on column-major data,
    /// where `op(A)` is `m x k`, `op(B)` is `k x n` and `C` is `m x n`.
    ///
    /// # Arguments
    ///
    /// * `context` - cuBLAS context whose handle executes the operation.
    /// * `transa`, `transb` - operation applied to `a` / `b` before multiplying.
    /// * `m`, `n`, `k` - matrix dimensions as described above.
    /// * `alpha`, `beta` - pointers to the two scalars.
    /// * `a`, `b`, `c` - pointers to the matrices; `c` receives the result.
    /// * `lda`, `ldb`, `ldc` - leading dimensions of `a`, `b` and `c`.
    ///
    /// # Errors
    ///
    /// Returns the [`Error`] that the cuBLAS status code maps to:
    /// `NotInitialized`, `InvalidValue`, `ArchMismatch`, `ExecutionFailed`, or
    /// `Unknown` for any other non-success status.
    ///
    /// NOTE(review): this function is not marked `unsafe` even though it hands
    /// raw pointers straight to cuBLAS. Callers must still make sure every
    /// pointer is valid for the given dimensions and lives in the memory space
    /// expected by the pointer mode configured on `context`.
    // The parameter names and count mirror the BLAS/cuBLAS signature on purpose.
    #[allow(clippy::many_single_char_names)]
    #[allow(clippy::too_many_arguments)]
    pub fn gemm(
        context: &Context,
        transa: Operation,
        transb: Operation,
        m: i32,
        n: i32,
        k: i32,
        alpha: *mut f32,
        a: *mut f32,
        lda: i32,
        b: *mut f32,
        ldb: i32,
        beta: *mut f32,
        c: *mut f32,
        ldc: i32,
    ) -> Result<(), Error> {
        unsafe {
            Self::ffi_sgemm(
                *context.id_c(),
                transa.as_c(),
                transb.as_c(),
                m,
                n,
                k,
                alpha,
                a,
                lda,
                b,
                ldb,
                beta,
                c,
                ldc,
            )
        }
    }

    /// Calls `cublasSgemm_v2` and translates its status code into a `Result`.
    ///
    /// All arguments are passed to the FFI function verbatim.
    ///
    /// # Errors
    ///
    /// * `CUBLAS_STATUS_NOT_INITIALIZED` -> `Error::NotInitialized`
    /// * `CUBLAS_STATUS_INVALID_VALUE` -> `Error::InvalidValue`
    /// * `CUBLAS_STATUS_ARCH_MISMATCH` -> `Error::ArchMismatch`
    /// * `CUBLAS_STATUS_EXECUTION_FAILED` -> `Error::ExecutionFailed`
    /// * any other non-success status -> `Error::Unknown` carrying the raw code
    ///
    /// # Safety
    ///
    /// `handle` must be a live cuBLAS handle, and `alpha`, `beta`, `a`, `b` and
    /// `c` must be pointers that cuBLAS may legally access for the supplied
    /// dimensions and leading dimensions. Nothing is validated on the Rust side.
    // The parameter names and count mirror the BLAS/cuBLAS signature on purpose.
    #[allow(clippy::many_single_char_names)]
    #[allow(clippy::too_many_arguments)]
    unsafe fn ffi_sgemm(
        handle: cublasHandle_t,
        transa: cublasOperation_t,
        transb: cublasOperation_t,
        m: i32,
        n: i32,
        k: i32,
        alpha: *mut f32,
        a: *mut f32,
        lda: i32,
        b: *mut f32,
        ldb: i32,
        beta: *mut f32,
        c: *mut f32,
        ldc: i32,
    ) -> Result<(), Error> {
        match cublasSgemm_v2(
            handle, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
        ) {
            cublasStatus_t::CUBLAS_STATUS_SUCCESS => Ok(()),
            cublasStatus_t::CUBLAS_STATUS_NOT_INITIALIZED => Err(Error::NotInitialized),
            cublasStatus_t::CUBLAS_STATUS_INVALID_VALUE => {
                Err(Error::InvalidValue("m, n, or k < 0"))
            }
            cublasStatus_t::CUBLAS_STATUS_ARCH_MISMATCH => Err(Error::ArchMismatch),
            cublasStatus_t::CUBLAS_STATUS_EXECUTION_FAILED => Err(Error::ExecutionFailed),
            // Any status not handled above is surfaced together with its raw
            // numeric code so it can be looked up in the cuBLAS documentation.
            status => Err(Error::Unknown(
                "Unable to calculate gemm (alpha * op(A) * op(B) + beta * C).",
                status as i32 as u64,
            )),
        }
    }
}
#[cfg(test)]
mod test {
    use crate::api::context::Context;
    use crate::api::enums::PointerMode;
    use crate::chore::*;
    use crate::co::tensor::SharedTensor;
    use crate::ffi::*;
    use crate::API;

    /// Multiplies a 3x2 tensor by a 2x3 tensor through `API::ffi_sgemm`, using
    /// CUDA memory for every operand (including the `alpha` / `beta` scalars),
    /// and checks the 3x3 result after reading it back on the native device.
    #[test]
    fn use_cuda_memory_for_gemm() {
        test_setup();

        let native = get_native_backend();
        let cuda = get_cuda_backend();

        // Scalars: alpha = 1, beta = 0.
        let alpha = filled_tensor(&native, 1, 1f32);
        let beta = filled_tensor(&native, 1, 0f32);

        // Left operand, shape [3, 2].
        let mut a = SharedTensor::<f32>::new(&vec![3, 2]);
        write_to_memory(
            a.write_only(native.device()).unwrap(),
            &[2f32, 5f32, 2f32, 5f32, 2f32, 5f32],
        );

        // Right operand, shape [2, 3].
        let mut b = SharedTensor::<f32>::new(&vec![2, 3]);
        write_to_memory(
            b.write_only(native.device()).unwrap(),
            &[4f32, 1f32, 1f32, 4f32, 1f32, 1f32],
        );

        // Output, shape [3, 3].
        let mut c = SharedTensor::<f32>::new(&vec![3, 3]);

        // Inner scope: the mutable borrow of `c` taken for the CUDA write must
        // end before `c` is read back on the native device below.
        {
            let op_a = cublasOperation_t::CUBLAS_OP_N;
            let op_b = cublasOperation_t::CUBLAS_OP_N;
            let (m, n, k) = (3, 3, 2);
            let (lda, ldb, ldc) = (2, 3, 3);

            let alpha_mem = alpha.read(cuda.device()).unwrap();
            let beta_mem = beta.read(cuda.device()).unwrap();
            let a_mem = a.read(cuda.device()).unwrap();
            let b_mem = b.read(cuda.device()).unwrap();
            let c_mem = c.write_only(cuda.device()).unwrap();

            let mut ctx = Context::new().unwrap();
            // The scalars are passed as CUDA memory addresses, so the handle
            // is switched to device pointer mode.
            ctx.set_pointer_mode(PointerMode::Device).unwrap();

            // Reinterprets a raw CUDA memory id as an `f32` pointer.
            let device_ptr = |id: u64| id as *mut f32;
            let alpha_ptr = device_ptr(*alpha_mem.id_c());
            let beta_ptr = device_ptr(*beta_mem.id_c());
            let a_ptr = device_ptr(*a_mem.id_c());
            let b_ptr = device_ptr(*b_mem.id_c());
            let c_ptr = device_ptr(*c_mem.id_c());

            // `b` is deliberately passed in the first matrix slot and `a` in
            // the second: cuBLAS assumes column-major storage (per its
            // documentation), so swapping the operands yields the product
            // a * b laid out in the order the assertion below expects.
            let outcome = unsafe {
                API::ffi_sgemm(
                    *ctx.id_c(),
                    op_a,
                    op_b,
                    m,
                    n,
                    k,
                    alpha_ptr,
                    b_ptr,
                    ldb,
                    a_ptr,
                    lda,
                    beta_ptr,
                    c_ptr,
                    ldc,
                )
            };
            outcome.unwrap();
        }

        let native_c = c.read(native.device()).unwrap();
        let expected = [28f32, 7f32, 7f32, 28f32, 7f32, 7f32, 28f32, 7f32, 7f32];
        assert_eq!(&expected, native_c.as_slice::<f32>());

        test_teardown();
    }
}