use std::ffi::c_int;
use std::ops::{Add, AddAssign, Div, Mul, MulAssign, Sub, SubAssign};
use super::{VectorIndex, VectorView, VectorViewMut};
use cudarc::cublas::sys as cublas;
use cudarc::cublas::CudaBlas;
use cudarc::driver::{
CudaFunction, CudaSlice, CudaView, CudaViewMut, DevicePtr, LaunchConfig, PushKernelArg,
};
use crate::{
Context, CudaContext, CudaMat, CudaType, DefaultDenseMatrix, IndexType, ScalarCuda, Scale,
Vector, VectorCommon,
};
/// Dynamic shared-memory callback for kernels that need no shared memory.
///
/// Passed to the occupancy query in `launch_config_2d`; the candidate block
/// size is ignored and the answer is always zero bytes.
extern "C" fn zero(_block_size: c_int) -> usize {
    0
}
/// Dynamic shared-memory callback used when launching the norm kernels
/// (see `lk_norm`).
///
/// Returns the number of bytes occupied by `block_size` values of type `T`.
/// The arithmetic is done in `c_int`, exactly as the occupancy API hands the
/// block size over.
extern "C" fn norm_blk_size<T: ScalarCuda>(block_size: std::ffi::c_int) -> usize {
    let bytes_per_value = std::mem::size_of::<T>() as c_int;
    (block_size * bytes_per_value) as usize
}
/// Dynamic shared-memory callback returning the bytes occupied by
/// `block_size` values of type `T`.
///
/// NOTE(review): no caller is visible in this part of the file; presumably it
/// is paired with the squared-norm kernel launch — confirm.
extern "C" fn squared_norm_blk_size<T: ScalarCuda>(block_size: std::ffi::c_int) -> usize {
    let bytes_per_value = std::mem::size_of::<T>() as c_int;
    (block_size * bytes_per_value) as usize
}
/// Dynamic shared-memory callback used when launching `vec_root_finding`.
///
/// Returns the bytes needed for `block_size` values of type `T` plus
/// `block_size` values of type `c_int`.
extern "C" fn root_finding_blk_size<T: ScalarCuda>(block_size: std::ffi::c_int) -> usize {
    // One `T` and one `c_int` per unit of `block_size`.
    let bytes_per_slot =
        std::mem::size_of::<T>() as c_int + std::mem::size_of::<c_int>() as c_int;
    (block_size * bytes_per_slot) as usize
}
impl CudaContext {
/// Builds the launch configuration for an element-wise kernel over an
/// `nstates` x `nbatch` index space.
///
/// The block size comes from the occupancy query on `f` (no dynamic shared
/// memory). The grid has `ceil(nstates / block_size)` blocks along x and
/// `nbatch` blocks along y.
///
/// # Panics
/// Panics if the occupancy query fails.
pub(crate) fn launch_config_2d(
    &self,
    nstates: u32,
    nbatch: u32,
    f: &CudaFunction,
) -> LaunchConfig {
    let (_, threads_per_block) = f
        .occupancy_max_potential_block_size(zero, 0, 0, None)
        .expect("Failed to get occupancy max potential block size");
    let blocks_x = nstates.div_ceil(threads_per_block);
    LaunchConfig {
        grid_dim: (blocks_x, nbatch, 1),
        block_dim: (threads_per_block, 1, 1),
        shared_mem_bytes: 0,
    }
}
/// Builds the launch configuration for a reduction kernel over an
/// `nstates` x `nbatch` index space.
///
/// The block size suggested by the occupancy query (which uses `smem_size_f`
/// to account for dynamic shared memory) is rounded *down* to a power of two.
/// `shared_mem_bytes` is then `smem_size_f` evaluated at that rounded size.
///
/// # Panics
/// Panics if the occupancy query fails.
pub(crate) fn launch_config_2d_reduce(
    &self,
    nstates: u32,
    nbatch: u32,
    f: &CudaFunction,
    smem_size_f: extern "C" fn(block_size: std::ffi::c_int) -> usize,
) -> LaunchConfig {
    let (_, suggested) = f
        .occupancy_max_potential_block_size(smem_size_f, 0, 0, None)
        .expect("Failed to get occupancy max potential block size");
    // Largest power of two not exceeding the suggestion (0 stays 0).
    // NOTE(review): presumably the reduction kernels require a power-of-two
    // block size — confirm against the kernel source.
    let threads_per_block = if suggested == 0 {
        0
    } else {
        1u32 << suggested.ilog2()
    };
    let blocks_x = nstates.div_ceil(threads_per_block);
    let smem_bytes = smem_size_f(threads_per_block as i32) as u32;
    LaunchConfig {
        grid_dim: (blocks_x, nbatch, 1),
        block_dim: (threads_per_block, 1, 1),
        shared_mem_bytes: smem_bytes,
    }
}
/// Computes a k-norm per batch on the device and returns the largest one.
///
/// The kernel (`vec_norm` for `k == 2`, otherwise `vec_norm_lk`, which also
/// receives `k`) writes one partial sum per block. The partial sums are
/// copied to the host, added up per batch, and turned into a norm with
/// `sqrt` (k = 2) or `pow(1 / k)`. The result is the maximum over all batches.
///
/// # Panics
/// Panics if allocation, the kernel launch or the device-to-host copy fails.
fn lk_norm<T: ScalarCuda>(&self, x: &CudaSlice<T>, nstates: usize, nbatch: usize, k: i32) -> T {
    let is_l2 = k == 2;
    let n_u32 = nstates as u32;
    let nb_u32 = nbatch as u32;
    let f = self.function::<T>(if is_l2 { "vec_norm" } else { "vec_norm_lk" });
    let config = self.launch_config_2d_reduce(n_u32, nb_u32, &f, norm_blk_size::<T>);
    let blocks_per_batch = config.grid_dim.0 as usize;
    // One slot per launched block; the host-side sum below reads every slot,
    // so the kernel is expected to write all of them.
    let mut d_partials = unsafe {
        self.stream
            .alloc::<T>(blocks_per_batch * nbatch)
            .expect("Failed to allocate memory for partial sums")
    };
    let x_stride = nstates as i32;
    let mut launch = self.stream.launch_builder(&f);
    launch
        .arg(x)
        .arg(&n_u32)
        .arg(&nb_u32)
        .arg(&x_stride);
    // Only the general kernel takes the exponent.
    if !is_l2 {
        launch.arg(&k);
    }
    launch.arg(&mut d_partials);
    unsafe { launch.launch(config) }.expect("Failed to launch kernel");
    let h_partials = self
        .stream
        .clone_dtoh(&d_partials)
        .expect("Failed to copy data from device to host");
    let mut largest = T::zero();
    for b in 0..nbatch {
        let first = b * blocks_per_batch;
        let mut total = T::zero();
        for &partial in &h_partials[first..first + blocks_per_batch] {
            total += partial;
        }
        let batch_norm = if is_l2 {
            total.sqrt()
        } else {
            total.pow(T::one() / T::from_f64(k as f64).unwrap())
        };
        if batch_norm > largest {
            largest = batch_norm;
        }
    }
    largest
}
}
/// Owned vector stored in device memory.
///
/// `data` holds `context.nbatch()` batches back to back; each batch has
/// `len()` elements and batch `b` starts at offset `b * len()` (see
/// `set_index` and `get_batch`).
#[derive(Debug, Clone)]
pub struct CudaVec<T: ScalarCuda> {
    /// Device buffer containing every batch.
    pub(crate) data: CudaSlice<T>,
    /// Context supplying the stream, the kernels and the batch count.
    pub(crate) context: CudaContext,
}
/// Index list stored in device memory as `c_int` values.
#[derive(Debug, Clone)]
pub struct CudaIndex {
    /// Device buffer of indices.
    pub(crate) data: CudaSlice<c_int>,
    /// Context whose stream is used for allocation and copies.
    pub(crate) context: CudaContext,
}
/// Immutable view into device vector data.
#[derive(Debug)]
pub struct CudaVecRef<'a, T: ScalarCuda> {
    /// Borrowed device data; `stride()` is `data.len() / context.nbatch()`.
    pub(crate) data: CudaView<'a, T>,
    /// Context supplying the stream, the kernels and the batch count.
    pub(crate) context: CudaContext,
    /// Number of elements per batch reported by `len()`.
    pub(crate) nstates: IndexType,
    /// Offset into `data` at which `kview()` starts.
    pub(crate) col_offset: IndexType,
}
/// Mutable view into device vector data.
#[derive(Debug)]
pub struct CudaVecMut<'a, T: ScalarCuda> {
    /// Mutably borrowed device data; `stride()` is `data.len() / context.nbatch()`.
    pub(crate) data: CudaViewMut<'a, T>,
    /// Context supplying the stream, the kernels and the batch count.
    pub(crate) context: CudaContext,
    /// Number of elements per batch reported by `len()`.
    pub(crate) nstates: IndexType,
    /// Offset into `data` at which `kview()` / `kview_mut()` start.
    pub(crate) col_offset: IndexType,
}
/// Declares `CudaMat<T>` as the default dense matrix type for `CudaVec<T>`.
impl<T: ScalarCuda> DefaultDenseMatrix for CudaVec<T> {
    type M = CudaMat<T>;
}
/// Kernel-facing helpers shared with the view types so the operator macros
/// can treat owned vectors and views uniformly.
impl<T: ScalarCuda> CudaVec<T> {
    /// Distance between the starts of consecutive batches; for an owned
    /// vector this is simply `len()`.
    pub(crate) fn stride(&self) -> IndexType {
        self.len()
    }
    /// Read-only view of the whole buffer, as passed to kernels.
    pub(crate) fn kview(&self) -> CudaView<'_, T> {
        self.data.as_view()
    }
    /// Mutable view of the whole buffer, as passed to kernels.
    pub(crate) fn kview_mut(&mut self) -> CudaViewMut<'_, T> {
        self.data.as_view_mut()
    }
}
/// Kernel-facing helpers for the immutable view.
impl<'a, T: ScalarCuda> CudaVecRef<'a, T> {
    /// Length of the underlying data divided by the context's batch count.
    pub(crate) fn stride(&self) -> IndexType {
        self.data.len() as IndexType / self.context.nbatch()
    }
    /// Number of elements per batch (`nstates`).
    pub(crate) fn len(&self) -> IndexType {
        self.nstates
    }
    /// View of the data starting at `col_offset`, as passed to kernels.
    pub(crate) fn kview(&self) -> CudaView<'_, T> {
        self.data.slice(self.col_offset..)
    }
}
/// Kernel-facing helpers for the mutable view.
impl<'a, T: ScalarCuda> CudaVecMut<'a, T> {
    /// Length of the underlying data divided by the context's batch count.
    pub(crate) fn stride(&self) -> IndexType {
        let total = self.data.len() as IndexType;
        total / self.context.nbatch()
    }

    /// Number of elements per batch (`nstates`).
    pub(crate) fn len(&self) -> IndexType {
        self.nstates
    }

    /// Read-only view of the data starting at `col_offset`.
    pub(crate) fn kview(&self) -> CudaView<'_, T> {
        let start = self.col_offset;
        self.data.slice(start..)
    }

    /// Mutable view of the data starting at `col_offset`.
    pub(crate) fn kview_mut(&mut self) -> CudaViewMut<'_, T> {
        let start = self.col_offset;
        self.data.slice_mut(start..)
    }

    /// Copies `other` into this view by launching the `vec_copy` kernel.
    ///
    /// Does nothing when this view has zero states.
    ///
    /// # Panics
    /// Panics if `assert_compatible_nbatch` rejects `other`'s batch count or
    /// if the kernel launch fails.
    pub(crate) fn copy_from_ref(&mut self, other: &CudaVecRef<'_, T>) {
        let dst_nbatch = self.context.nbatch();
        let src_nbatch = other.context.nbatch();
        self.context
            .assert_compatible_nbatch(src_nbatch, "copy_from_view");
        let n_u32 = self.nstates as u32;
        if n_u32 == 0 {
            return;
        }
        let src_nbatch_i32 = src_nbatch as i32;
        let f = self.context.function::<T>("vec_copy");
        let config = self
            .context
            .launch_config_2d(n_u32, dst_nbatch as u32, &f);
        let dst_stride = self.stride() as i32;
        let src_stride = other.stride() as i32;
        let dst_offset = self.col_offset;
        let src_offset = other.col_offset;
        let mut launch = self.context.stream.launch_builder(&f);
        // Slice the fields directly (rather than via `kview_mut`) so the
        // mutable borrow of `data` stays disjoint from the borrow of
        // `context` held by the launch builder.
        let mut dst = self.data.slice_mut(dst_offset..);
        let src = other.data.slice(src_offset..);
        launch
            .arg(&mut dst)
            .arg(&src)
            .arg(&n_u32)
            .arg(&dst_stride)
            .arg(&src_stride)
            .arg(&src_nbatch_i32);
        unsafe { launch.launch(config) }.expect("Failed to launch kernel");
    }
}
/// Implements `VectorCommon` for a vector type without a lifetime parameter.
///
/// `$vec` is the implementing type, `$con` becomes the associated type `C`,
/// and `$in` becomes `Inner`, the type of the `data` field returned by
/// `inner()`.
macro_rules! impl_vector_common {
    ($vec:ty, $con:ty, $in:ty) => {
        impl<T: ScalarCuda> VectorCommon for $vec {
            type T = T;
            type C = $con;
            type Inner = $in;
            fn inner(&self) -> &Self::Inner {
                &self.data
            }
        }
    };
}
/// Implements `VectorCommon` for a view type that carries a lifetime `'a`.
///
/// Identical to `impl_vector_common!` except that the generated impl is also
/// generic over `'a`, which `$vec` and `$in` may mention.
macro_rules! impl_vector_common_ref {
    ($vec:ty, $con:ty, $in:ty) => {
        impl<'a, T: ScalarCuda> VectorCommon for $vec {
            type T = T;
            type C = $con;
            type Inner = $in;
            fn inner(&self) -> &Self::Inner {
                &self.data
            }
        }
    };
}
// `VectorCommon` for the owned vector and both view types; `Inner` is the
// type of each one's `data` field.
impl_vector_common!(CudaVec<T>, CudaContext, CudaSlice<T>);
impl_vector_common_ref!(CudaVecRef<'a, T>, CudaContext, CudaView<'a, T>);
impl_vector_common_ref!(CudaVecMut<'a, T>, CudaContext, CudaViewMut<'a, T>);
/// Implements `Mul<Scale<T>>` for an owned vector by scaling it in place and
/// returning it.
///
/// `[$($g)*]` are the impl generics, `$lhs` the left-hand type and `$out` the
/// `Output` type (which must be what `self` is, since `self` is returned).
macro_rules! impl_mul_scalar {
    ([$($g:tt)*], $lhs:ty, $out:ty) => {
        impl<$($g)*> Mul<Scale<T>> for $lhs {
            type Output = $out;

            /// Launches `vec_mul_assign_scalar` on `self`; a vector with zero
            /// states is returned untouched.
            fn mul(mut self, rhs: Scale<T>) -> Self::Output {
                let ctx = self.context.clone();
                let f = ctx.function::<T>("vec_mul_assign_scalar");
                let nb_u32 = ctx.nbatch() as u32;
                let n_u32 = self.len() as u32;
                if n_u32 == 0 {
                    return self;
                }
                let stride = self.stride() as i32;
                let factor = rhs.value();
                {
                    // Scope the launch so the mutable view of `self` is
                    // released before `self` is returned.
                    let mut launch = ctx.stream.launch_builder(&f);
                    let mut target = self.kview_mut();
                    launch
                        .arg(&mut target)
                        .arg(&factor)
                        .arg(&n_u32)
                        .arg(&nb_u32)
                        .arg(&stride);
                    let config = ctx.launch_config_2d(n_u32, nb_u32, &f);
                    unsafe { launch.launch(config) }.expect("Failed to launch kernel");
                }
                self
            }
        }
    };
}
/// Implements `Mul<Scale<T>>` for a borrowed vector or view by writing the
/// scaled values into a freshly allocated output.
///
/// `[$($g)*]` are the impl generics, `$lhs` the left-hand type and `$out` the
/// owned `Output` type, which is created with `zeros`.
macro_rules! impl_mul_scalar_alloc {
    ([$($g:tt)*], $lhs:ty, $out:ty) => {
        impl<$($g)*> Mul<Scale<T>> for $lhs {
            type Output = $out;

            /// Launches `vec_mul_scalar`, reading from `self` and writing to
            /// a new zero-initialised vector of the same length. With zero
            /// states the empty result is returned without a launch.
            fn mul(self, rhs: Scale<T>) -> Self::Output {
                let ctx = self.context.clone();
                let nbatch = ctx.nbatch();
                let len = self.len();
                let mut out = Self::Output::zeros(len, ctx.clone());
                let f = ctx.function::<T>("vec_mul_scalar");
                let n_u32 = len as u32;
                if n_u32 == 0 {
                    return out;
                }
                let nb_u32 = nbatch as u32;
                let src_stride = self.stride() as i32;
                let src_nbatch = nbatch as i32;
                // The output is a fresh owned vector, so its stride is its length.
                let out_stride = len as i32;
                let factor = rhs.value();
                {
                    let mut launch = ctx.stream.launch_builder(&f);
                    let src = self.kview();
                    launch
                        .arg(&src)
                        .arg(&factor)
                        .arg(&mut out.data)
                        .arg(&n_u32)
                        .arg(&out_stride)
                        .arg(&src_stride)
                        .arg(&src_nbatch);
                    let config = ctx.launch_config_2d(n_u32, nb_u32, &f);
                    unsafe { launch.launch(config) }.expect("Failed to launch kernel");
                }
                out
            }
        }
    };
}
/// Implements `Div<Scale<T>>` for `$lhs` as multiplication by the reciprocal
/// of the scalar, reusing the type's `Mul<Scale<T>>` implementation.
macro_rules! impl_div_scalar {
    ([$($g:tt)*], $lhs:ty, $out:ty) => {
        impl<$($g)*> Div<Scale<T>> for $lhs {
            type Output = $out;

            /// Computes `self * (1 / rhs)`.
            fn div(self, rhs: Scale<T>) -> Self::Output {
                let reciprocal: T = T::one() / rhs.value();
                self * Scale(reciprocal)
            }
        }
    };
}
/// Implements `MulAssign<Scale<T>>` for `$col_type` via the
/// `vec_mul_assign_scalar` kernel.
///
/// `[$($g)*]` are the impl generics.
macro_rules! impl_mul_assign_scalar {
    ([$($g:tt)*], $col_type:ty) => {
        impl<$($g)*> MulAssign<Scale<T>> for $col_type {
            /// Scales `self` in place; does nothing for zero states.
            fn mul_assign(&mut self, rhs: Scale<T>) {
                let ctx = self.context.clone();
                let f = ctx.function::<T>("vec_mul_assign_scalar");
                let nb_u32 = ctx.nbatch() as u32;
                let n_u32 = self.len() as u32;
                if n_u32 == 0 {
                    return;
                }
                let stride = self.stride() as i32;
                let factor = rhs.value();
                let mut launch = ctx.stream.launch_builder(&f);
                let mut target = self.kview_mut();
                launch
                    .arg(&mut target)
                    .arg(&factor)
                    .arg(&n_u32)
                    .arg(&nb_u32)
                    .arg(&stride);
                let config = ctx.launch_config_2d(n_u32, nb_u32, &f);
                unsafe { launch.launch(config) }.expect("Failed to launch kernel");
            }
        }
    };
}
// Scalar multiplication and division.
// An owned `CudaVec` is scaled in place and returned.
impl_mul_scalar!([T: ScalarCuda], CudaVec<T>, CudaVec<T>);
// A borrowed `CudaVec` produces a newly allocated result.
impl_mul_scalar_alloc!([T: ScalarCuda], &CudaVec<T>, CudaVec<T>);
impl_div_scalar!([T: ScalarCuda], CudaVec<T>, CudaVec<T>);
impl_mul_assign_scalar!([T: ScalarCuda], CudaVec<T>);
// Views: multiplication always allocates a new `CudaVec`; only the mutable
// view supports `*=`.
impl_mul_scalar_alloc!(['a, T: ScalarCuda], CudaVecRef<'a, T>, CudaVec<T>);
impl_mul_assign_scalar!(['a, T: ScalarCuda], CudaVecMut<'a, T>);
impl_mul_scalar_alloc!(['a, T: ScalarCuda], CudaVecMut<'a, T>, CudaVec<T>);
/// Implements a compound-assignment operator (`$Op` / `$method`) for `$Lhs`
/// with both a borrowed (`$RhsRef`) and an owned (`$RhsOwned`) right-hand side.
///
/// `$kernel` is the name of the device kernel to launch and `$label` the
/// operation name handed to `assert_compatible_nbatch`. The owned variant
/// simply forwards to the borrowed one.
macro_rules! impl_assign {
    (
        [$($g:tt)*], $Op:ident, $method:ident, $kernel:expr, $label:expr,
        $Lhs:ty, $RhsRef:ty, $RhsOwned:ty
    ) => {
        impl<$($g)*> $Op<$RhsOwned> for $Lhs {
            fn $method(&mut self, rhs: $RhsOwned) {
                self.$method(&rhs);
            }
        }

        impl<$($g)*> $Op<$RhsRef> for $Lhs {
            /// Launches `$kernel` with `self` as destination and `rhs` as
            /// source; does nothing when `self` has zero states.
            fn $method(&mut self, rhs: $RhsRef) {
                let ctx = self.context.clone();
                let lhs_nbatch = ctx.nbatch();
                let rhs_nbatch = rhs.context.nbatch();
                ctx.assert_compatible_nbatch(rhs_nbatch, $label);
                let len = self.len();
                if len == 0 {
                    return;
                }
                let f = ctx.function::<T>($kernel);
                let len_u32 = len as u32;
                let lhs_nbatch_u32 = lhs_nbatch as u32;
                let lhs_stride = self.stride() as i32;
                let rhs_stride = rhs.stride() as i32;
                let rhs_nbatch_i32 = rhs_nbatch as i32;
                let config = ctx.launch_config_2d(len_u32, lhs_nbatch_u32, &f);
                let mut launch = ctx.stream.launch_builder(&f);
                let mut lhs_data = self.kview_mut();
                let rhs_data = rhs.kview();
                launch
                    .arg(&mut lhs_data)
                    .arg(&rhs_data)
                    .arg(&len_u32)
                    .arg(&lhs_stride)
                    .arg(&rhs_stride)
                    .arg(&rhs_nbatch_i32);
                unsafe { launch.launch(config) }.expect(concat!("Failed to launch ", $kernel));
            }
        }
    };
}
/// Implements a binary operator (`$Op` / `$method`) for `$Lhs` and `$Rhs`
/// that allocates a new `CudaVec<T>` for the result.
///
/// `$kernel` is the name of the device kernel to launch and `$label` the
/// operation name handed to `assert_compatible_nbatch`.
macro_rules! impl_binary {
    (
        [$($g:tt)*], $Op:ident, $method:ident, $kernel:expr, $label:expr,
        $Lhs:ty, $Rhs:ty
    ) => {
        impl<$($g)*> $Op<$Rhs> for $Lhs {
            type Output = CudaVec<T>;

            /// Launches `$kernel` over `self` and `rhs`, writing into a new
            /// zero-initialised vector with `self`'s length and context.
            /// With zero states the empty result is returned without a launch.
            fn $method(self, rhs: $Rhs) -> CudaVec<T> {
                let ctx = self.context.clone();
                let lhs_nbatch = ctx.nbatch();
                let rhs_nbatch = rhs.context.nbatch();
                ctx.assert_compatible_nbatch(rhs_nbatch, $label);
                let len = self.len();
                let mut out = CudaVec::zeros(len, ctx.clone());
                if len == 0 {
                    return out;
                }
                let f = ctx.function::<T>($kernel);
                let len_u32 = len as u32;
                let lhs_nbatch_u32 = lhs_nbatch as u32;
                let lhs_stride = self.stride() as i32;
                let rhs_stride = rhs.stride() as i32;
                let rhs_nbatch_i32 = rhs_nbatch as i32;
                // The output is a fresh owned vector: its stride is `len` and
                // its batch count is that of `self`'s context.
                let out_stride = len as i32;
                let out_nbatch = lhs_nbatch as i32;
                let config = ctx.launch_config_2d(len_u32, lhs_nbatch_u32, &f);
                {
                    let mut launch = ctx.stream.launch_builder(&f);
                    let lhs_data = self.kview();
                    let rhs_data = rhs.kview();
                    launch
                        .arg(&lhs_data)
                        .arg(&rhs_data)
                        .arg(&mut out.data)
                        .arg(&len_u32)
                        .arg(&lhs_stride)
                        .arg(&rhs_stride)
                        .arg(&rhs_nbatch_i32)
                        .arg(&out_stride)
                        .arg(&out_nbatch);
                    unsafe { launch.launch(config) }.expect(concat!("Failed to launch ", $kernel));
                }
                out
            }
        }
    };
}
// `-=` for every (destination, source) pairing: the destination is an owned
// `CudaVec` or a `CudaVecMut`, the source an owned vector or `CudaVecRef`,
// taken either by reference or by value.
impl_assign!(
    [T: ScalarCuda],
    SubAssign, sub_assign, "vec_sub_assign", "sub_assign",
    CudaVec<T>, &CudaVec<T>, CudaVec<T>
);
impl_assign!(
    ['a, T: ScalarCuda],
    SubAssign, sub_assign, "vec_sub_assign", "sub_assign",
    CudaVec<T>, &CudaVecRef<'a, T>, CudaVecRef<'a, T>
);
impl_assign!(
    ['a, T: ScalarCuda],
    SubAssign, sub_assign, "vec_sub_assign", "sub_assign",
    CudaVecMut<'a, T>, &CudaVec<T>, CudaVec<T>
);
impl_assign!(
    ['a, 'b, T: ScalarCuda],
    SubAssign, sub_assign, "vec_sub_assign", "sub_assign",
    CudaVecMut<'a, T>, &CudaVecRef<'b, T>, CudaVecRef<'b, T>
);
// `+=` for the same pairings.
impl_assign!(
    [T: ScalarCuda],
    AddAssign, add_assign, "vec_add_assign", "add_assign",
    CudaVec<T>, &CudaVec<T>, CudaVec<T>
);
impl_assign!(
    ['a, T: ScalarCuda],
    AddAssign, add_assign, "vec_add_assign", "add_assign",
    CudaVec<T>, &CudaVecRef<'a, T>, CudaVecRef<'a, T>
);
impl_assign!(
    ['a, T: ScalarCuda],
    AddAssign, add_assign, "vec_add_assign", "add_assign",
    CudaVecMut<'a, T>, &CudaVec<T>, CudaVec<T>
);
impl_assign!(
    ['a, 'b, T: ScalarCuda],
    AddAssign, add_assign, "vec_add_assign", "add_assign",
    CudaVecMut<'a, T>, &CudaVecRef<'b, T>, CudaVecRef<'b, T>
);
// Allocating `-` between borrowed operands (owned vector or view on either side).
impl_binary!([T: ScalarCuda], Sub, sub, "vec_sub", "sub", &CudaVec<T>, &CudaVec<T>);
impl_binary!(
    ['a, T: ScalarCuda], Sub, sub, "vec_sub", "sub", &CudaVec<T>, &CudaVecRef<'a, T>
);
impl_binary!(
    ['a, T: ScalarCuda], Sub, sub, "vec_sub", "sub", &CudaVecRef<'a, T>, &CudaVec<T>
);
impl_binary!(
    ['a, 'b, T: ScalarCuda], Sub, sub, "vec_sub", "sub",
    &CudaVecRef<'a, T>, &CudaVecRef<'b, T>
);
// Subtraction with an owned `CudaVec` on the left: reuse the left operand's
// buffer via `-=` and hand it back.

/// `CudaVec - CudaVec`, computed in place in the left operand.
impl<T: ScalarCuda> Sub<CudaVec<T>> for CudaVec<T> {
    type Output = CudaVec<T>;
    fn sub(mut self, rhs: CudaVec<T>) -> CudaVec<T> {
        self -= &rhs;
        self
    }
}

/// `CudaVec - &CudaVec`, computed in place in the left operand.
impl<T: ScalarCuda> Sub<&CudaVec<T>> for CudaVec<T> {
    type Output = CudaVec<T>;
    fn sub(mut self, rhs: &CudaVec<T>) -> CudaVec<T> {
        self -= rhs;
        self
    }
}

/// `CudaVec - CudaVecRef`, computed in place in the left operand.
impl<T: ScalarCuda> Sub<CudaVecRef<'_, T>> for CudaVec<T> {
    type Output = CudaVec<T>;
    fn sub(mut self, rhs: CudaVecRef<'_, T>) -> CudaVec<T> {
        self -= &rhs;
        self
    }
}

/// `CudaVec - &CudaVecRef`, computed in place in the left operand.
impl<T: ScalarCuda> Sub<&CudaVecRef<'_, T>> for CudaVec<T> {
    type Output = CudaVec<T>;
    fn sub(mut self, rhs: &CudaVecRef<'_, T>) -> CudaVec<T> {
        self -= rhs;
        self
    }
}

// Subtraction with a borrowed vector or a view on the left: forward to the
// reference-reference implementations, which allocate the result.

/// `&CudaVec - CudaVec`.
impl<T: ScalarCuda> Sub<CudaVec<T>> for &CudaVec<T> {
    type Output = CudaVec<T>;
    fn sub(self, rhs: CudaVec<T>) -> CudaVec<T> {
        self - &rhs
    }
}

/// `&CudaVec - CudaVecRef`.
impl<T: ScalarCuda> Sub<CudaVecRef<'_, T>> for &CudaVec<T> {
    type Output = CudaVec<T>;
    fn sub(self, rhs: CudaVecRef<'_, T>) -> CudaVec<T> {
        self - &rhs
    }
}

/// `CudaVecRef - CudaVec`.
impl<T: ScalarCuda> Sub<CudaVec<T>> for CudaVecRef<'_, T> {
    type Output = CudaVec<T>;
    fn sub(self, rhs: CudaVec<T>) -> CudaVec<T> {
        Sub::sub(&self, &rhs)
    }
}

/// `CudaVecRef - CudaVecRef`.
impl<T: ScalarCuda> Sub<CudaVecRef<'_, T>> for CudaVecRef<'_, T> {
    type Output = CudaVec<T>;
    fn sub(self, rhs: CudaVecRef<'_, T>) -> CudaVec<T> {
        Sub::sub(&self, &rhs)
    }
}

/// `CudaVecRef - &CudaVec`.
impl<'a, T: ScalarCuda> Sub<&CudaVec<T>> for CudaVecRef<'a, T> {
    type Output = CudaVec<T>;
    fn sub(self, rhs: &CudaVec<T>) -> CudaVec<T> {
        Sub::sub(&self, rhs)
    }
}

/// `CudaVecRef - &CudaVecRef`.
impl<'a, 'b, T: ScalarCuda> Sub<&CudaVecRef<'b, T>> for CudaVecRef<'a, T> {
    type Output = CudaVec<T>;
    fn sub(self, rhs: &CudaVecRef<'b, T>) -> CudaVec<T> {
        Sub::sub(&self, rhs)
    }
}
// Allocating `+` between borrowed operands (owned vector or view on either side).
impl_binary!([T: ScalarCuda], Add, add, "vec_add", "add", &CudaVec<T>, &CudaVec<T>);
impl_binary!(
    ['a, T: ScalarCuda], Add, add, "vec_add", "add", &CudaVec<T>, &CudaVecRef<'a, T>
);
impl_binary!(
    ['a, T: ScalarCuda], Add, add, "vec_add", "add", &CudaVecRef<'a, T>, &CudaVec<T>
);
impl_binary!(
    ['a, 'b, T: ScalarCuda], Add, add, "vec_add", "add",
    &CudaVecRef<'a, T>, &CudaVecRef<'b, T>
);
// Addition with an owned `CudaVec` on the left: reuse the left operand's
// buffer via `+=` and hand it back.

/// `CudaVec + CudaVec`, computed in place in the left operand.
impl<T: ScalarCuda> Add<CudaVec<T>> for CudaVec<T> {
    type Output = CudaVec<T>;
    fn add(mut self, rhs: CudaVec<T>) -> CudaVec<T> {
        self += &rhs;
        self
    }
}

/// `CudaVec + &CudaVec`, computed in place in the left operand.
impl<T: ScalarCuda> Add<&CudaVec<T>> for CudaVec<T> {
    type Output = CudaVec<T>;
    fn add(mut self, rhs: &CudaVec<T>) -> CudaVec<T> {
        self += rhs;
        self
    }
}

/// `CudaVec + CudaVecRef`, computed in place in the left operand.
impl<T: ScalarCuda> Add<CudaVecRef<'_, T>> for CudaVec<T> {
    type Output = CudaVec<T>;
    fn add(mut self, rhs: CudaVecRef<'_, T>) -> CudaVec<T> {
        self += &rhs;
        self
    }
}

/// `CudaVec + &CudaVecRef`, computed in place in the left operand.
impl<T: ScalarCuda> Add<&CudaVecRef<'_, T>> for CudaVec<T> {
    type Output = CudaVec<T>;
    fn add(mut self, rhs: &CudaVecRef<'_, T>) -> CudaVec<T> {
        self += rhs;
        self
    }
}

// Addition with a borrowed vector or a view on the left: forward to the
// reference-reference implementations, which allocate the result.

/// `&CudaVec + CudaVec`.
impl<T: ScalarCuda> Add<CudaVec<T>> for &CudaVec<T> {
    type Output = CudaVec<T>;
    fn add(self, rhs: CudaVec<T>) -> CudaVec<T> {
        self + &rhs
    }
}

/// `&CudaVec + CudaVecRef`.
impl<T: ScalarCuda> Add<CudaVecRef<'_, T>> for &CudaVec<T> {
    type Output = CudaVec<T>;
    fn add(self, rhs: CudaVecRef<'_, T>) -> CudaVec<T> {
        self + &rhs
    }
}

/// `CudaVecRef + CudaVec`.
impl<T: ScalarCuda> Add<CudaVec<T>> for CudaVecRef<'_, T> {
    type Output = CudaVec<T>;
    fn add(self, rhs: CudaVec<T>) -> CudaVec<T> {
        Add::add(&self, &rhs)
    }
}

/// `CudaVecRef + CudaVecRef`.
impl<T: ScalarCuda> Add<CudaVecRef<'_, T>> for CudaVecRef<'_, T> {
    type Output = CudaVec<T>;
    fn add(self, rhs: CudaVecRef<'_, T>) -> CudaVec<T> {
        Add::add(&self, &rhs)
    }
}

/// `CudaVecRef + &CudaVec`.
impl<'a, T: ScalarCuda> Add<&CudaVec<T>> for CudaVecRef<'a, T> {
    type Output = CudaVec<T>;
    fn add(self, rhs: &CudaVec<T>) -> CudaVec<T> {
        Add::add(&self, rhs)
    }
}

/// `CudaVecRef + &CudaVecRef`.
impl<'a, 'b, T: ScalarCuda> Add<&CudaVecRef<'b, T>> for CudaVecRef<'a, T> {
    type Output = CudaVec<T>;
    fn add(self, rhs: &CudaVecRef<'b, T>) -> CudaVec<T> {
        Add::add(&self, rhs)
    }
}
/// Device-resident index list; indices are stored as `c_int` and converted
/// to and from `IndexType` with `as` casts at the host boundary.
impl VectorIndex for CudaIndex {
    type C = CudaContext;

    /// Returns the context this index list was created with.
    fn context(&self) -> &Self::C {
        &self.context
    }

    /// Number of indices stored.
    fn len(&self) -> IndexType {
        self.data.len() as IndexType
    }

    /// Allocates `len` zero-initialised indices on `ctx`'s stream.
    ///
    /// # Panics
    /// Panics if the device allocation fails.
    fn zeros(len: IndexType, ctx: Self::C) -> Self {
        let data = ctx
            .stream
            .alloc_zeros(len)
            .expect("Failed to allocate memory for CudaIndex");
        Self { data, context: ctx }
    }

    /// Copies the indices back to the host.
    ///
    /// # Panics
    /// Panics if the device-to-host copy fails.
    fn clone_as_vec(&self) -> Vec<IndexType> {
        self.context
            .stream
            .clone_dtoh(&self.data)
            .expect("Failed to copy data from device to host")
            .into_iter()
            .map(|x| x as IndexType)
            .collect()
    }

    /// Uploads host indices to the device.
    ///
    /// NOTE(review): the `as c_int` cast truncates silently for values that
    /// do not fit in a `c_int`.
    ///
    /// # Panics
    /// Panics if the allocation or the host-to-device copy fails.
    fn from_vec(v: Vec<IndexType>, ctx: Self::C) -> Self {
        // The buffer is overwritten in full by the copy below before it is
        // stored in the returned value.
        let mut data = unsafe {
            ctx.stream
                .alloc(v.len())
                .expect("Failed to allocate memory for CudaIndex")
        };
        let v = v.into_iter().map(|x| x as c_int).collect::<Vec<_>>();
        ctx.stream
            .memcpy_htod(&v, &mut data)
            .expect("Failed to copy data from host to device");
        Self { data, context: ctx }
    }
}
impl<T: ScalarCuda> Vector for CudaVec<T> {
type View<'a> = CudaVecRef<'a, T>;
type ViewMut<'a> = CudaVecMut<'a, T>;
type Index = CudaIndex;
/// Returns the context this vector was created with.
fn context(&self) -> &Self::C {
    &self.context
}
/// Mutable access to the underlying device buffer.
fn inner_mut(&mut self) -> &mut Self::Inner {
    &mut self.data
}
/// Reads a single element back from the device.
///
/// Only supported when the context has at most one batch.
///
/// # Panics
/// Panics if the context has more than one batch, if `index` is out of
/// bounds, or if the device-to-host copy fails.
fn get_index(&self, index: IndexType) -> Self::T {
    let nbatch = self.context.nbatch();
    if nbatch > 1 {
        panic!("get_index not supported for batched vectors");
    }
    // Same check and message as `set_index`, so an out-of-range index is
    // reported here rather than from inside the slicing call below.
    assert!(
        index < self.data.len() as IndexType,
        "Index out of bounds"
    );
    self.context
        .stream
        .clone_dtoh(&self.data.slice(index..index + 1))
        .expect("Failed to copy data from device to host")[0]
}
/// Writes `value` at position `index` of every batch.
///
/// # Panics
/// Panics if `index >= len()` or if a host-to-device copy fails.
fn set_index(&mut self, index: IndexType, value: Self::T) {
    let nbatch = self.context.nbatch();
    let nstates = self.len();
    assert!(index < nstates, "Index out of bounds");
    let host = std::slice::from_ref(&value);
    // Batch `b` starts at `b * nstates`.
    for offset in (0..nbatch).map(|b| b * nstates + index) {
        self.context
            .stream
            .memcpy_htod(host, &mut self.data.slice_mut(offset..offset + 1))
            .expect("Failed to copy data from host to device");
    }
}
fn norm(&self, k: i32) -> Self::T {
let nbatch = self.context.nbatch();
let nstates = self.len();
if nstates == 0 {
return T::zero();
}
if k == 2 && nbatch == 1 {
let blas =
CudaBlas::new(self.context.stream.clone()).expect("Failed to create CudaBlas");
let n = self.data.len() as c_int;
let (x, _) = self.data.device_ptr(&self.context.stream);
let result: T = match T::as_enum() {
CudaType::F64 => {
let x = x as *const f64;
let mut result_f64 = 0.0;
unsafe {
cublas::cublasDnrm2_v2(*blas.handle(), n, x, 1, &mut result_f64 as *mut f64)
}
.result()
.expect("Failed to call cublasDnrm2_v2");
T::from_f64(result_f64).unwrap()
}
};
return result;
}
self.context.lk_norm(&self.data, nstates, nbatch, k)
}
/// Delegates to the view implementation of `squared_norm` on `self.as_view()`.
fn squared_norm(&self, y: &Self, atol: &Self, rtol: Self::T) -> Self::T {
    self.as_view().squared_norm(y, atol, rtol)
}
/// Number of elements per batch: the buffer length divided by the context's
/// batch count.
fn len(&self) -> IndexType {
    self.data.len() as IndexType / self.context.nbatch()
}
/// Uploads a host vector holding all batches to the device.
///
/// # Panics
/// Panics if `v.len()` is not a multiple of `ctx.nbatch()`, or if the
/// allocation or the host-to-device copy fails.
fn from_vec(v: Vec<Self::T>, ctx: Self::C) -> Self {
    let nbatch = ctx.nbatch();
    let total = v.len();
    assert!(
        total % nbatch == 0,
        "vector length {} must be divisible by nbatch {}",
        total,
        nbatch
    );
    // The buffer is overwritten in full by the copy below.
    let mut data = unsafe {
        ctx.stream
            .alloc(total)
            .expect("Failed to allocate memory for CudaVec")
    };
    ctx.stream
        .memcpy_htod(&v, &mut data)
        .expect("Failed to copy data from host to device");
    Self { data, context: ctx }
}
/// Uploads a host slice holding all batches to the device.
///
/// # Panics
/// Panics if `slice.len()` is not a multiple of `ctx.nbatch()`, or if the
/// allocation or the host-to-device copy fails.
fn from_slice(slice: &[Self::T], ctx: Self::C) -> Self {
    let nbatch = ctx.nbatch();
    let total = slice.len();
    assert!(
        total % nbatch == 0,
        "slice length {} must be divisible by nbatch {}",
        total,
        nbatch
    );
    // The buffer is overwritten in full by the copy below.
    let mut data = unsafe {
        ctx.stream
            .alloc(total)
            .expect("Failed to allocate memory for CudaVec")
    };
    ctx.stream
        .memcpy_htod(slice, &mut data)
        .expect("Failed to copy data from host to device");
    Self { data, context: ctx }
}
/// Creates a vector of `nstates` elements per batch, all set to `value`.
///
/// # Panics
/// Panics if the allocation or the fill kernel launch fails.
fn from_element(nstates: usize, value: Self::T, ctx: Self::C) -> Self {
    let total = nstates * ctx.nbatch();
    // Allocated without initialisation; `fill` below sets the contents
    // (and is a no-op when there is nothing to fill).
    let data = unsafe {
        ctx.stream
            .alloc(total)
            .expect("Failed to allocate memory for CudaVec")
    };
    let mut filled = Self { data, context: ctx };
    filled.fill(value);
    filled
}
/// Creates a zero-initialised vector with `nstates` elements per batch.
///
/// # Panics
/// Panics if the device allocation fails.
fn zeros(nstates: usize, ctx: Self::C) -> Self {
    let total = nstates * ctx.nbatch();
    let data = ctx
        .stream
        .alloc_zeros(total)
        .expect("Failed to allocate memory for CudaVec");
    Self { data, context: ctx }
}
/// Sets every element of every batch to `value` via the `vec_fill` kernel.
///
/// Does nothing for a vector with zero states.
///
/// # Panics
/// Panics if the kernel launch fails.
fn fill(&mut self, value: Self::T) {
    let nstates = self.len();
    let n_u32 = nstates as u32;
    let nb_u32 = self.context.nbatch() as u32;
    if n_u32 == 0 {
        return;
    }
    let f = self.context.function::<T>("vec_fill");
    let config = self.context.launch_config_2d(n_u32, nb_u32, &f);
    // For an owned vector the batch stride equals the number of states.
    let stride = nstates as i32;
    let mut launch = self.context.stream.launch_builder(&f);
    launch
        .arg(&mut self.data)
        .arg(&value)
        .arg(&n_u32)
        .arg(&nb_u32)
        .arg(&stride);
    unsafe { launch.launch(config) }.expect("Failed to launch kernel");
}
/// Borrows the whole vector as an immutable view (offset 0, same context).
fn as_view(&self) -> Self::View<'_> {
    let len = self.len();
    CudaVecRef {
        data: self.data.as_view(),
        context: self.context.clone(),
        nstates: len,
        col_offset: 0,
    }
}
/// Borrows the whole vector as a mutable view (offset 0, same context).
fn as_view_mut(&mut self) -> Self::ViewMut<'_> {
    // Read the length before `data` is mutably borrowed.
    let len = self.len();
    CudaVecMut {
        data: self.data.as_view_mut(),
        context: self.context.clone(),
        nstates: len,
        col_offset: 0,
    }
}
/// Copies `other` into `self` by delegating to `copy_from_view`.
fn copy_from(&mut self, other: &Self) {
    self.copy_from_view(&other.as_view());
}
/// Copies the contents of a view into `self` via the `vec_copy` kernel.
///
/// Does nothing when `self` has zero states.
///
/// # Panics
/// Panics if `assert_compatible_nbatch` rejects `other`'s batch count or if
/// the kernel launch fails.
fn copy_from_view(&mut self, other: &Self::View<'_>) {
    let dst_nbatch = self.context.nbatch();
    let src_nbatch = other.context.nbatch();
    self.context
        .assert_compatible_nbatch(src_nbatch, "copy_from_view");
    let n_u32 = self.len() as u32;
    if n_u32 == 0 {
        return;
    }
    let f = self.context.function::<T>("vec_copy");
    let config = self
        .context
        .launch_config_2d(n_u32, dst_nbatch as u32, &f);
    let dst_stride = n_u32 as i32;
    let src_stride = other.stride() as i32;
    let src = other.data.slice(other.col_offset..);
    let src_nbatch_i32 = src_nbatch as i32;
    let mut launch = self.context.stream.launch_builder(&f);
    launch
        .arg(&mut self.data)
        .arg(&src)
        .arg(&n_u32)
        .arg(&dst_stride)
        .arg(&src_stride)
        .arg(&src_nbatch_i32);
    unsafe { launch.launch(config) }.expect("Failed to launch kernel");
}
/// Delegates to `axpy_v` with a view of `x`.
fn axpy(&mut self, alpha: Self::T, x: &Self, beta: Self::T) {
    self.axpy_v(alpha, &x.as_view(), beta);
}
/// Launches the `vec_axpy` kernel with `self` as destination, `x` as source
/// and the scalars `alpha` and `beta`.
///
/// NOTE(review): presumably computes `self = alpha * x + beta * self` —
/// confirm against the kernel source.
///
/// Does nothing when `self` has zero states.
///
/// # Panics
/// Panics if `assert_compatible_nbatch` rejects `x`'s batch count or if the
/// kernel launch fails.
fn axpy_v(&mut self, alpha: Self::T, x: &Self::View<'_>, beta: Self::T) {
    let dst_nbatch = self.context.nbatch();
    let src_nbatch = x.context.nbatch();
    self.context.assert_compatible_nbatch(src_nbatch, "axpy_v");
    let nstates = self.len();
    if nstates == 0 {
        return;
    }
    let n_u32 = nstates as u32;
    let f = self.context.function::<T>("vec_axpy");
    let config = self
        .context
        .launch_config_2d(n_u32, dst_nbatch as u32, &f);
    let dst_stride = nstates as i32;
    let src_stride = x.stride() as i32;
    let src_nbatch_i32 = src_nbatch as i32;
    let src = x.data.slice(x.col_offset..);
    let mut launch = self.context.stream.launch_builder(&f);
    launch
        .arg(&mut self.data)
        .arg(&src)
        .arg(&alpha)
        .arg(&beta)
        .arg(&n_u32)
        .arg(&dst_stride)
        .arg(&src_stride)
        .arg(&src_nbatch_i32);
    unsafe { launch.launch(config) }.expect("Failed to launch kernel");
}
/// Launches the `vec_batched_axpy` kernel with one `alpha` per batch.
///
/// `alpha` is uploaded to a temporary device buffer before the launch.
/// Does nothing when `self` has zero states.
///
/// # Panics
/// Panics if `alpha.len()` differs from the batch count of `self`, if
/// `assert_compatible_nbatch` rejects `x`'s batch count, or if the
/// allocation, the upload or the kernel launch fails.
fn batched_axpy(&mut self, alpha: &[Self::T], x: &Self, beta: Self::T) {
    let dst_nbatch = self.context.nbatch();
    let src_nbatch = x.context.nbatch();
    assert_eq!(
        alpha.len(),
        dst_nbatch,
        "batched_axpy: alpha.len() must equal self.nbatch()"
    );
    self.context
        .assert_compatible_nbatch(src_nbatch, "batched_axpy");
    let nstates = self.len();
    if nstates == 0 {
        return;
    }
    let src_nstates = x.len();
    // Overwritten in full by the upload that follows.
    let mut d_alpha = unsafe {
        self.context
            .stream
            .alloc::<T>(dst_nbatch)
            .expect("Failed to allocate device memory for batched_axpy alpha")
    };
    self.context
        .stream
        .memcpy_htod(alpha, &mut d_alpha)
        .expect("Failed to copy alpha to device");
    let n_u32 = nstates as u32;
    let f = self.context.function::<T>("vec_batched_axpy");
    let config = self
        .context
        .launch_config_2d(n_u32, dst_nbatch as u32, &f);
    let dst_stride = nstates as i32;
    let src_stride = src_nstates as i32;
    let src_nbatch_i32 = src_nbatch as i32;
    let mut launch = self.context.stream.launch_builder(&f);
    launch
        .arg(&mut self.data)
        .arg(&x.data)
        .arg(&d_alpha)
        .arg(&beta)
        .arg(&n_u32)
        .arg(&dst_stride)
        .arg(&src_stride)
        .arg(&src_nbatch_i32);
    unsafe { launch.launch(config) }.expect("Failed to launch batched_axpy kernel");
}
/// Copies the whole device buffer (all batches) back to a host `Vec`.
///
/// # Panics
/// Panics if the device-to-host copy fails.
fn clone_as_vec(&self) -> Vec<Self::T> {
    self.context
        .stream
        .clone_dtoh(&self.data)
        .expect("Failed to copy data from device to host")
}
/// Launches the `vec_mul_assign` kernel with `self` as destination and
/// `other` as source (element-wise product, per the method name).
///
/// Does nothing when `self` has zero states.
///
/// # Panics
/// Panics if `assert_compatible_nbatch` rejects `other`'s batch count or if
/// the kernel launch fails.
fn component_mul_assign(&mut self, other: &Self) {
    let dst_nbatch = self.context.nbatch();
    let src_nbatch = other.context.nbatch();
    self.context
        .assert_compatible_nbatch(src_nbatch, "component_mul_assign");
    let n_u32 = self.len() as u32;
    let src_nstates = other.len();
    if n_u32 == 0 {
        return;
    }
    let f = self.context.function::<T>("vec_mul_assign");
    let config = self
        .context
        .launch_config_2d(n_u32, dst_nbatch as u32, &f);
    let dst_stride = n_u32 as i32;
    let src_stride = src_nstates as i32;
    let src_nbatch_i32 = src_nbatch as i32;
    let mut launch = self.context.stream.launch_builder(&f);
    launch
        .arg(&mut self.data)
        .arg(&other.data)
        .arg(&n_u32)
        .arg(&dst_stride)
        .arg(&src_stride)
        .arg(&src_nbatch_i32);
    unsafe { launch.launch(config) }.expect("Failed to launch kernel");
}
/// Launches the `vec_div_assign` kernel with `self` as destination and
/// `other` as source (element-wise quotient, per the method name).
///
/// Does nothing when `self` has zero states.
///
/// # Panics
/// Panics if `assert_compatible_nbatch` rejects `other`'s batch count or if
/// the kernel launch fails.
fn component_div_assign(&mut self, other: &Self) {
    let dst_nbatch = self.context.nbatch();
    let src_nbatch = other.context.nbatch();
    self.context
        .assert_compatible_nbatch(src_nbatch, "component_div_assign");
    let n_u32 = self.len() as u32;
    let src_nstates = other.len();
    if n_u32 == 0 {
        return;
    }
    let f = self.context.function::<T>("vec_div_assign");
    let config = self
        .context
        .launch_config_2d(n_u32, dst_nbatch as u32, &f);
    let dst_stride = n_u32 as i32;
    let src_stride = src_nstates as i32;
    let src_nbatch_i32 = src_nbatch as i32;
    let mut launch = self.context.stream.launch_builder(&f);
    launch
        .arg(&mut self.data)
        .arg(&other.data)
        .arg(&n_u32)
        .arg(&dst_stride)
        .arg(&src_stride)
        .arg(&src_nbatch_i32);
    unsafe { launch.launch(config) }.expect("Failed to launch kernel");
}
/// Runs the `vec_root_finding` kernel on `self` (g0) and `g1` and reduces
/// its per-block output on the host.
///
/// Returns `(found_root, max_val, max_idx)` where `found_root` is the
/// per-batch flag written by the kernel, and `max_val` / `max_idx` are the
/// largest per-block value strictly greater than zero and its index
/// (`0` / `-1` when no block reports a positive value). For an empty vector
/// the result is `(false, 0, -1)`.
///
/// # Panics
/// Panics if the two vectors have different lengths, if any device
/// operation fails, or if batches disagree on `found_root` or `max_idx`.
fn root_finding(&self, g1: &Self) -> (bool, Self::T, i32) {
    let nbatch = self.context.nbatch();
    let nstates = self.len();
    if nstates == 0 {
        return (false, Self::T::zero(), -1);
    }
    let g1_nbatch = g1.context.nbatch();
    let g1_nstates = g1.len();
    assert_eq!(
        nstates, g1_nstates,
        "Vector length mismatch: {} != {}",
        nstates, g1_nstates
    );
    let n_u32 = nstates as u32;
    let nb_u32 = nbatch as u32;
    let f = self.context.function::<T>("vec_root_finding");
    let config = self.context.launch_config_2d_reduce(
        n_u32,
        nb_u32,
        &f,
        root_finding_blk_size::<T>,
    );
    let blocks_per_batch = config.grid_dim.0 as usize;
    let n_blocks = blocks_per_batch * nbatch;
    // Per-block outputs are allocated uninitialised; the host reduction
    // below reads every slot, so the kernel is expected to write them all.
    let mut d_vals = unsafe {
        self.context
            .stream
            .alloc::<T>(n_blocks)
            .expect("Failed to allocate memory for max_vals")
    };
    let mut d_idxs = unsafe {
        self.context
            .stream
            .alloc::<c_int>(n_blocks)
            .expect("Failed to allocate memory for max_idxs")
    };
    // One flag per batch, starting at zero.
    let mut d_flags = self
        .context
        .stream
        .alloc_zeros::<c_int>(nbatch)
        .expect("Failed to allocate memory for root_flag");
    let g0_stride = nstates as i32;
    let g0_nbatch_i32 = nbatch as i32;
    let g1_stride = g1_nstates as i32;
    let g1_nbatch_i32 = g1_nbatch as i32;
    let mut launch = self.context.stream.launch_builder(&f);
    launch
        .arg(&self.data)
        .arg(&g1.data)
        .arg(&n_u32)
        .arg(&nb_u32)
        .arg(&g0_stride)
        .arg(&g0_nbatch_i32)
        .arg(&g1_stride)
        .arg(&g1_nbatch_i32)
        .arg(&mut d_flags)
        .arg(&mut d_vals)
        .arg(&mut d_idxs);
    unsafe { launch.launch(config) }.expect("Failed to launch kernel");
    let stream = &self.context.stream;
    let h_vals = stream
        .clone_dtoh(&d_vals)
        .expect("Failed to copy data from device to host");
    let h_idxs = stream
        .clone_dtoh(&d_idxs)
        .expect("Failed to copy data from device to host");
    let h_flags = stream
        .clone_dtoh(&d_flags)
        .expect("Failed to copy data from device to host");
    // Reduce each batch's blocks, and require all batches to agree with the
    // first one on the flag and the index.
    let mut agreed: Option<(bool, T, i32)> = None;
    for (b, &flag) in h_flags.iter().enumerate().take(nbatch) {
        let first_block = b * blocks_per_batch;
        let last_block = first_block + blocks_per_batch;
        let mut best_val = T::zero();
        let mut best_idx = -1;
        let vals = &h_vals[first_block..last_block];
        let idxs = &h_idxs[first_block..last_block];
        for (&val, &idx) in vals.iter().zip(idxs) {
            if val > best_val {
                best_val = val;
                best_idx = idx;
            }
        }
        let result = (flag != 0, best_val, best_idx);
        if let Some(ref reference) = agreed {
            if reference.0 != result.0 || reference.2 != result.2 {
                panic!(
                    "Root finding results differ across batches: batch 0 = {:?}, batch {} = {:?}",
                    reference, b, result
                );
            }
        } else {
            agreed = Some(result);
        }
    }
    agreed.unwrap()
}
/// Launches the `vec_assign_at_indices` kernel, passing `self`, the index
/// list and `value` (per the method name, it assigns `value` at the listed
/// positions).
///
/// Does nothing when `indices` is empty.
///
/// # Panics
/// Panics if the kernel launch fails.
fn assign_at_indices(&mut self, indices: &Self::Index, value: Self::T) {
    let nbatch = self.context.nbatch();
    let n_idx_u32 = indices.len() as u32;
    if n_idx_u32 == 0 {
        return;
    }
    let nstates = self.len();
    let f = self.context.function::<T>("vec_assign_at_indices");
    // The launch grid spans the indices, not the states.
    let config = self
        .context
        .launch_config_2d(n_idx_u32, nbatch as u32, &f);
    let stride = nstates as i32;
    let nbatch_i32 = nbatch as i32;
    let mut launch = self.context.stream.launch_builder(&f);
    launch
        .arg(&mut self.data)
        .arg(&indices.data)
        .arg(&value)
        .arg(&n_idx_u32)
        .arg(&stride)
        .arg(&nbatch_i32);
    unsafe { launch.launch(config) }.expect("Failed to launch kernel");
}
/// Launches the `vec_copy_from_indices` kernel with `self` as destination,
/// `other` as source and `indices` as the index list.
///
/// Does nothing when `indices` is empty.
///
/// # Panics
/// Panics if the kernel launch fails.
fn copy_from_indices(&mut self, other: &Self, indices: &Self::Index) {
    let dst_nbatch = self.context.nbatch();
    let src_nbatch = other.context.nbatch();
    let n_idx_u32 = indices.len() as u32;
    if n_idx_u32 == 0 {
        return;
    }
    let nstates = self.len();
    let f = self.context.function::<T>("vec_copy_from_indices");
    // The launch grid spans the indices, not the states.
    let config = self
        .context
        .launch_config_2d(n_idx_u32, dst_nbatch as u32, &f);
    let dst_stride = nstates as i32;
    let dst_nbatch_i32 = dst_nbatch as i32;
    let src_stride = other.len() as i32;
    let src_nbatch_i32 = src_nbatch as i32;
    let mut launch = self.context.stream.launch_builder(&f);
    launch
        .arg(&mut self.data)
        .arg(&other.data)
        .arg(&indices.data)
        .arg(&n_idx_u32)
        .arg(&dst_stride)
        .arg(&dst_nbatch_i32)
        .arg(&src_stride)
        .arg(&src_nbatch_i32);
    unsafe { launch.launch(config) }.expect("Failed to launch kernel");
}
/// Launches the `vec_gather` kernel with `self` as destination, `other` as
/// source and `indices` as the index list.
///
/// Does nothing when `indices` is empty.
///
/// # Panics
/// Panics if the kernel launch fails.
fn gather(&mut self, other: &Self, indices: &Self::Index) {
    let dst_nbatch = self.context.nbatch();
    let src_nbatch = other.context.nbatch();
    let n_idx_u32 = indices.len() as u32;
    if n_idx_u32 == 0 {
        return;
    }
    let nstates = self.len();
    let f = self.context.function::<T>("vec_gather");
    // The launch grid spans the indices, not the states.
    let config = self
        .context
        .launch_config_2d(n_idx_u32, dst_nbatch as u32, &f);
    let dst_stride = nstates as i32;
    let dst_nbatch_i32 = dst_nbatch as i32;
    let src_stride = other.len() as i32;
    let src_nbatch_i32 = src_nbatch as i32;
    let mut launch = self.context.stream.launch_builder(&f);
    launch
        .arg(&mut self.data)
        .arg(&other.data)
        .arg(&indices.data)
        .arg(&n_idx_u32)
        .arg(&dst_stride)
        .arg(&dst_nbatch_i32)
        .arg(&src_stride)
        .arg(&src_nbatch_i32);
    unsafe { launch.launch(config) }.expect("Failed to launch kernel");
}
/// Launches the `vec_scatter` kernel with `self` as source, `indices` as the
/// index list and `other` as destination.
///
/// The grid is sized over `(indices.len(), self nbatch)`. Each vector's
/// per-batch length (`len()`) is handed to the kernel as its stride, together
/// with its batch count. Returns without launching when `indices` is empty.
///
/// NOTE(review): presumably the kernel writes `other[indices[k]] = self[k]`
/// — confirm against the kernel source.
///
/// # Panics
/// Panics if the kernel launch fails.
fn scatter(&self, indices: &Self::Index, other: &mut Self) {
    let src_batches = self.context.nbatch();
    let dst_batches = other.context.nbatch();
    let index_count = indices.len() as u32;
    if index_count == 0 {
        // Nothing selected: skip the launch entirely.
        return;
    }

    let kernel = self.context.function::<T>("vec_scatter");
    let launch_cfg = self
        .context
        .launch_config_2d(index_count, src_batches as u32, &kernel);

    // Scalar kernel arguments are materialised before the builder so that
    // every reference pushed below outlives it.
    let src_stride = self.len() as i32;
    let src_nbatch = src_batches as i32;
    let dst_stride = other.len() as i32;
    let dst_nbatch = dst_batches as i32;

    let mut launcher = self.context.stream.launch_builder(&kernel);
    launcher
        .arg(&self.data)
        .arg(&indices.data)
        .arg(&mut other.data)
        .arg(&index_count)
        .arg(&src_stride)
        .arg(&src_nbatch)
        .arg(&dst_stride)
        .arg(&dst_nbatch);
    unsafe { launcher.launch(launch_cfg) }.expect("Failed to launch kernel");
}
/// Returns a read-only view over the elements of batch number `batch`.
///
/// The view covers `batch * len() .. (batch + 1) * len()` of the backing
/// storage and carries a context cloned with a batch count of one.
///
/// # Panics
/// Panics with "Batch index out of bounds" when `batch` is not smaller than
/// the context's batch count, or if cloning the context fails.
fn get_batch(&self, batch: usize) -> Self::View<'_> {
    let batch_count = self.context.nbatch();
    let per_batch = self.len();
    assert!(batch < batch_count, "Batch index out of bounds");

    let first = batch * per_batch;
    let last = first + per_batch;
    CudaVecRef {
        data: self.data.slice(first..last),
        context: self.context.clone_with_nbatch(1).unwrap(),
        nstates: per_batch,
        col_offset: 0,
    }
}
/// Returns a mutable view over the elements of batch number `batch`.
///
/// The view covers `batch * len() .. (batch + 1) * len()` of the backing
/// storage and carries a context cloned with a batch count of one.
///
/// # Panics
/// Panics with "Batch index out of bounds" when `batch` is not smaller than
/// the context's batch count, or if cloning the context fails.
fn get_batch_mut(&mut self, batch: usize) -> Self::ViewMut<'_> {
    let batch_count = self.context.nbatch();
    let per_batch = self.len();
    assert!(batch < batch_count, "Batch index out of bounds");

    let first = batch * per_batch;
    let last = first + per_batch;
    CudaVecMut {
        data: self.data.slice_mut(first..last),
        context: self.context.clone_with_nbatch(1).unwrap(),
        nstates: per_batch,
        col_offset: 0,
    }
}
}
impl<T: ScalarCuda> VectorView<'_> for CudaVecRef<'_, T> {
type Owned = CudaVec<T>;
fn get_index(&self, index: IndexType) -> Self::T {
let nbatch = self.context.nbatch();
if nbatch > 1 {
panic!("get_index not supported for batched views");
}
let offset = self.col_offset + index;
self.context
.stream
.clone_dtoh(&self.data.slice(offset..offset + 1))
.expect("Failed to copy data from device to host")[0]
}
fn into_owned(self) -> Self::Owned {
let nbatch = self.context.nbatch();
let stride = self.stride();
let total_valid = self.nstates * nbatch;
if stride == self.nstates && self.col_offset == 0 {
let mut ret = unsafe { self.context.stream.alloc(self.data.len()) }
.expect("Failed to allocate memory for CudaVec");
self.context
.stream
.memcpy_dtod(&self.data, &mut ret)
.expect("Failed to copy data from device to device");
Self::Owned {
data: ret,
context: self.context,
}
} else {
let mut ret = unsafe {
self.context
.stream
.alloc(total_valid)
.expect("Failed to allocate memory for CudaVec")
};
for b in 0..nbatch {
let src_start = b * stride + self.col_offset;
let dst_start = b * self.nstates;
let src_slice = self.data.slice(src_start..src_start + self.nstates);
let mut dst_slice = ret.slice_mut(dst_start..dst_start + self.nstates);
self.context
.stream
.memcpy_dtod(&src_slice, &mut dst_slice)
.expect("Failed to copy data from device to device");
}
Self::Owned {
data: ret,
context: self.context,
}
}
}
fn squared_norm(&self, y: &Self::Owned, atol: &Self::Owned, rtol: Self::T) -> Self::T {
let nbatch = self.context.nbatch();
let nstates = self.nstates;
if nstates == 0 {
return Self::T::zero();
}
let atol_nbatch = atol.context.nbatch();
let y_nbatch = y.context.nbatch();
let nstates_u32 = nstates as u32;
let nbatch_u32 = nbatch as u32;
let f = self.context.function::<T>("vec_squared_norm");
let config = self.context.launch_config_2d_reduce(
nstates_u32,
nbatch_u32,
&f,
squared_norm_blk_size::<T>,
);
let blocks_per_batch = config.grid_dim.0 as usize;
let total_blocks = blocks_per_batch * nbatch;
let mut partial_sums = unsafe {
self.context
.stream
.alloc::<T>(total_blocks)
.expect("Failed to allocate memory for partial sums")
};
let mut build = self.context.stream.launch_builder(&f);
let self_data = self.data.slice(self.col_offset..);
let y_stride = self.stride() as i32;
let y_nbatch_i32 = nbatch as i32;
let y0_stride = y.len() as i32;
let y0_nbatch_i32 = y_nbatch as i32;
let atol_stride = atol.len() as i32;
let atol_nbatch_i32 = atol_nbatch as i32;
let rtol_val = rtol;
build
.arg(&self_data)
.arg(&y.data)
.arg(&atol.data)
.arg(&rtol_val)
.arg(&nstates_u32)
.arg(&nbatch_u32)
.arg(&y_stride)
.arg(&y_nbatch_i32)
.arg(&y0_stride)
.arg(&y0_nbatch_i32)
.arg(&atol_stride)
.arg(&atol_nbatch_i32)
.arg(&mut partial_sums);
unsafe { build.launch(config) }.expect("Failed to launch kernel");
let partial_sums = self
.context
.stream
.clone_dtoh(&partial_sums)
.expect("Failed to copy data from device to host");
let nstates_t = T::from_f64(nstates as f64).unwrap();
let mut max_norm = T::zero();
for b in 0..nbatch {
let start = b * blocks_per_batch;
let sum = partial_sums[start..start + blocks_per_batch]
.iter()
.fold(T::zero(), |acc, x| acc + *x);
let norm = sum / nstates_t;
if norm > max_norm {
max_norm = norm;
}
}
max_norm
}
}
/// Mutating operations on a borrowed device vector view.
impl<'a, T: ScalarCuda> VectorViewMut<'a> for CudaVecMut<'a, T> {
    type Owned = CudaVec<T>;
    type View = CudaVecRef<'a, T>;
    type Index = CudaIndex;

    /// Overwrites this view with the contents of an owned vector by
    /// delegating to `copy_from_ref` on a view of `other`.
    fn copy_from(&mut self, other: &Self::Owned) {
        self.copy_from_ref(&other.as_view());
    }

    /// Overwrites this view with the contents of another view by delegating
    /// to `copy_from_ref`.
    fn copy_from_view(&mut self, other: &Self::View) {
        self.copy_from_ref(other);
    }

    /// Writes `value` at position `index` of every batch, one host-to-device
    /// copy per batch.
    ///
    /// # Panics
    /// Panics with "Index out of bounds" when `index` is not smaller than
    /// `nstates`, or if a host-to-device copy fails.
    fn set_index(&mut self, index: IndexType, value: Self::T) {
        let batch_count = self.context.nbatch();
        assert!(index < self.nstates, "Index out of bounds");
        // Single-element host buffer reused for every batch.
        let host_value = vec![value];
        for batch in 0..batch_count {
            let pos = batch * self.stride() + self.col_offset + index;
            let mut cell = self.data.slice_mut(pos..pos + 1);
            self.context
                .stream
                .memcpy_htod(&host_value, &mut cell)
                .expect("Failed to copy data from host to device");
        }
    }

    /// Launches the `vec_axpy` kernel on this view with `x`, `alpha` and
    /// `beta`.
    ///
    /// The batch counts of `self` and `x` are checked with
    /// `assert_compatible_nbatch` first; an empty view returns without
    /// launching.
    ///
    /// NOTE(review): presumably the kernel computes
    /// `self = alpha * x + beta * self` — confirm against the kernel source.
    ///
    /// # Panics
    /// Panics if the batch counts are incompatible or the launch fails.
    fn axpy(&mut self, alpha: Self::T, x: &Self::Owned, beta: Self::T) {
        let batch_count = self.context.nbatch();
        let x_batch_count = x.context.nbatch();
        self.context.assert_compatible_nbatch(x_batch_count, "axpy");

        let element_count = self.nstates as u32;
        if element_count == 0 {
            return;
        }

        let kernel = self.context.function::<T>("vec_axpy");
        let launch_cfg = self
            .context
            .launch_config_2d(element_count, batch_count as u32, &kernel);

        // Scalar kernel arguments, prepared before the view is mutably sliced.
        let dst_stride = self.stride() as i32;
        let src_stride = x.len() as i32;
        let src_nbatch = x_batch_count as i32;
        let first_elem = self.col_offset;

        // The kernel sees the view starting at its column offset.
        let mut target = self.data.slice_mut(first_elem..);
        let mut launcher = self.context.stream.launch_builder(&kernel);
        launcher
            .arg(&mut target)
            .arg(&x.data)
            .arg(&alpha)
            .arg(&beta)
            .arg(&element_count)
            .arg(&dst_stride)
            .arg(&src_stride)
            .arg(&src_nbatch);
        unsafe { launcher.launch(launch_cfg) }.expect("Failed to launch kernel");
    }
}
// Test module: instantiates the vector test macros from the grandparent module
// for `CudaVec<f64>`.
#[cfg(test)]
mod tests {
use super::*;
// Non-batched test macro, invoked with the `cuda` label.
// NOTE(review): presumably expands to the shared single-batch vector tests —
// confirm against the macro definition.
super::super::generate_vector_tests_nonbatched!(cuda, CudaVec<f64>);
// Batched test macro, given two default contexts configured with 2 and 3
// batches respectively.
super::super::generate_vector_tests_batched!(
cuda,
CudaVec<f64>,
CudaContext::default().with_nbatch(2),
CudaContext::default().with_nbatch(3)
);
}