use hodu_cuda_kernels::{compat::*, kernel::Kernels, kernels::*};
/// Opens the CUDA context used by every test in this file.
///
/// # Panics
/// Panics if `CudaContext::new` returns an error for ordinal 0.
fn device() -> Arc<cudarc::driver::CudaContext> {
    // All tests run against the first device.
    let ordinal = 0;
    cudarc::driver::CudaContext::new(ordinal).unwrap()
}
/// Builds a fresh `Kernels` value for a test to hand to the `call_ops_*` launchers.
///
/// NOTE(review): presumably this holds the compiled/loaded kernel modules —
/// confirm against `hodu_cuda_kernels::kernel::Kernels`.
fn kernels() -> Kernels {
Kernels::new()
}
/// Rounds every element of `v` to `digits` decimal places.
///
/// Each value is scaled by `10^digits`, rounded to the nearest integer and
/// scaled back, so results can be compared with `assert_eq!` without tripping
/// over tiny floating-point differences.
fn approx(v: Vec<f32>, digits: i32) -> Vec<f32> {
    let scale = 10f32.powi(digits);
    v.into_iter()
        .map(|value| (value * scale).round() / scale)
        .collect()
}
/// Multiplies a 2x3 matrix by a 3x2 matrix through `call_ops_dot` and checks
/// the resulting 2x2 product.
#[test]
fn dot_f32_simple() {
    let kern = kernels();
    let ctx = device();
    let stream = ctx.default_stream();

    let (m, k, n) = (2, 3, 2);

    let a_host = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let b_host = vec![7.0f32, 8.0, 9.0, 10.0, 11.0, 12.0];
    let a_dev = stream.memcpy_stod(&a_host).unwrap();
    let b_dev = stream.memcpy_stod(&b_host).unwrap();

    // SAFETY: `alloc` leaves the device memory uninitialized; the buffer is only
    // read back after the kernel call below, which is expected to write all
    // `m * n` elements.
    let mut out_dev: cudarc::driver::CudaSlice<f32> = unsafe { stream.alloc(m * n).unwrap() };

    // NOTE(review): presumably [M, K, N, lhs strides (2), rhs strides (2),
    // lhs offset, rhs offset] — confirm against `call_ops_dot`.
    let metadata = vec![m, k, n, 3, 1, 2, 1, 0, 0];
    call_ops_dot(
        dot::F32,
        &kern,
        &ctx,
        &a_dev,
        &b_dev,
        &mut out_dev,
        &metadata,
    )
    .unwrap();

    let mut host_out = vec![0.0f32; m * n];
    stream.memcpy_dtoh(&out_dev, &mut host_out).unwrap();

    let expected = vec![58.0, 64.0, 139.0, 154.0];
    assert_eq!(approx(host_out, 4), expected);
}
/// Runs a 2-D matmul (2x3 times 3x2, no batch dimensions) through
/// `call_ops_matmul` and checks the 2x2 product.
#[test]
fn matmul_f32_2d() {
    let kern = kernels();
    let ctx = device();
    let stream = ctx.default_stream();

    let (m, k, n) = (2, 3, 2);
    let out_len = m * n;
    let (lhs_rank, rhs_rank, batch_rank) = (2, 2, 0);

    let a_host = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let b_host = vec![7.0f32, 8.0, 9.0, 10.0, 11.0, 12.0];
    let a_dev = stream.memcpy_stod(&a_host).unwrap();
    let b_dev = stream.memcpy_stod(&b_host).unwrap();

    // SAFETY: `alloc` leaves the device memory uninitialized; the buffer is only
    // read back after the kernel call below, which is expected to write all
    // `out_len` elements.
    let mut out_dev: cudarc::driver::CudaSlice<f32> = unsafe { stream.alloc(out_len).unwrap() };

    // NOTE(review): the grouping comments describe how the values line up with
    // the shapes used here; the authoritative layout is whatever
    // `call_ops_matmul` expects — confirm there.
    let metadata = vec![
        out_len, lhs_rank, rhs_rank, batch_rank, // element count and ranks
        m, k, // lhs shape
        k, n, // rhs shape
        k, 1, // lhs strides
        n, 1, // rhs strides
        0, 0, // lhs / rhs offsets
        m, k, n, // problem size
    ];
    call_ops_matmul(
        matmul::F32,
        &kern,
        &ctx,
        &a_dev,
        &b_dev,
        &mut out_dev,
        &metadata,
    )
    .unwrap();

    let mut host_out = vec![0.0f32; out_len];
    stream.memcpy_dtoh(&out_dev, &mut host_out).unwrap();

    let expected = vec![58.0, 64.0, 139.0, 154.0];
    assert_eq!(approx(host_out, 4), expected);
}
/// Multiplies two 3x3 matrices (1..=9 ascending by 9..=1 descending) through
/// `call_ops_matmul` and checks the 3x3 product.
#[test]
fn matmul_f32_square() {
    let kern = kernels();
    let ctx = device();
    let stream = ctx.default_stream();

    let (m, k, n) = (3, 3, 3);
    let out_len = m * n;
    let (lhs_rank, rhs_rank, batch_rank) = (2, 2, 0);

    let a_host = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
    let b_host = vec![9.0f32, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0];
    let a_dev = stream.memcpy_stod(&a_host).unwrap();
    let b_dev = stream.memcpy_stod(&b_host).unwrap();

    // SAFETY: `alloc` leaves the device memory uninitialized; the buffer is only
    // read back after the kernel call below, which is expected to write all
    // `out_len` elements.
    let mut out_dev: cudarc::driver::CudaSlice<f32> = unsafe { stream.alloc(out_len).unwrap() };

    // NOTE(review): the grouping comments describe how the values line up with
    // the shapes used here; the authoritative layout is whatever
    // `call_ops_matmul` expects — confirm there.
    let metadata = vec![
        out_len, lhs_rank, rhs_rank, batch_rank, // element count and ranks
        m, k, // lhs shape
        k, n, // rhs shape
        k, 1, // lhs strides
        n, 1, // rhs strides
        0, 0, // lhs / rhs offsets
        m, k, n, // problem size
    ];
    call_ops_matmul(
        matmul::F32,
        &kern,
        &ctx,
        &a_dev,
        &b_dev,
        &mut out_dev,
        &metadata,
    )
    .unwrap();

    let mut host_out = vec![0.0f32; out_len];
    stream.memcpy_dtoh(&out_dev, &mut host_out).unwrap();

    let expected = vec![30.0, 24.0, 18.0, 84.0, 69.0, 54.0, 138.0, 114.0, 90.0];
    assert_eq!(approx(host_out, 4), expected);
}
/// Feeds the 3x3 identity matrix to `call_ops_dot` as both operands and checks
/// that the output is exactly the identity again.
#[test]
fn dot_f32_identity() {
    let kern = kernels();
    let ctx = device();
    let stream = ctx.default_stream();

    let (m, k, n) = (3, 3, 3);

    let eye = vec![1.0f32, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];
    // One device copy serves as both the left and the right operand.
    let eye_dev = stream.memcpy_stod(&eye).unwrap();

    // SAFETY: `alloc` leaves the device memory uninitialized; the buffer is only
    // read back after the kernel call below, which is expected to write all
    // `m * n` elements.
    let mut out_dev: cudarc::driver::CudaSlice<f32> = unsafe { stream.alloc(m * n).unwrap() };

    // NOTE(review): presumably [M, K, N, lhs strides (2), rhs strides (2),
    // lhs offset, rhs offset] — confirm against `call_ops_dot`.
    let metadata = vec![m, k, n, 3, 1, 3, 1, 0, 0];
    call_ops_dot(
        dot::F32,
        &kern,
        &ctx,
        &eye_dev,
        &eye_dev,
        &mut out_dev,
        &metadata,
    )
    .unwrap();

    let mut host_out = vec![0.0f32; m * n];
    stream.memcpy_dtoh(&out_dev, &mut host_out).unwrap();

    // Compared without rounding, unlike the other tests in this file.
    assert_eq!(host_out, eye);
}
/// Runs a batched matmul through `call_ops_matmul`: a 2x2x3 left operand
/// against a single 3x2 right operand, checking all `batch * m * n` outputs.
#[test]
fn matmul_f32_batch() {
    let kern = kernels();
    let ctx = device();
    let stream = ctx.default_stream();

    let batch = 2;
    let (m, k, n) = (2, 3, 2);
    let out_len = batch * m * n;
    let (lhs_rank, rhs_rank, batch_rank) = (3, 2, 1);

    let a_host = vec![
        1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, // first batch entry
        7.0, 8.0, 9.0, 10.0, 11.0, 12.0, // second batch entry
    ];
    let b_host = vec![1.0f32, 0.0, 0.0, 1.0, 1.0, 1.0];
    let a_dev = stream.memcpy_stod(&a_host).unwrap();
    let b_dev = stream.memcpy_stod(&b_host).unwrap();

    // SAFETY: `alloc` leaves the device memory uninitialized; the buffer is only
    // read back after the kernel call below, which is expected to write all
    // `out_len` elements.
    let mut out_dev: cudarc::driver::CudaSlice<f32> = unsafe { stream.alloc(out_len).unwrap() };

    // NOTE(review): the grouping comments describe how the values line up with
    // the shapes used here; the authoritative layout is whatever
    // `call_ops_matmul` expects — confirm there.
    let metadata = vec![
        out_len, lhs_rank, rhs_rank, batch_rank, // element count and ranks
        batch, m, k, // lhs shape
        k, n, // rhs shape
        batch, // batch shape
        m * k, k, 1, // lhs strides
        n, 1, // rhs strides
        0, 0, // lhs / rhs offsets
        m, k, n, // problem size
    ];
    call_ops_matmul(
        matmul::F32,
        &kern,
        &ctx,
        &a_dev,
        &b_dev,
        &mut out_dev,
        &metadata,
    )
    .unwrap();

    let mut host_out = vec![0.0f32; out_len];
    stream.memcpy_dtoh(&out_dev, &mut host_out).unwrap();

    let expected = vec![4.0, 5.0, 10.0, 11.0, 16.0, 17.0, 22.0, 23.0];
    assert_eq!(approx(host_out, 4), expected);
}