use cudarc::driver::PushKernelArg;
use cudarc::driver::safe::{CudaContext, CudaStream};
use std::sync::Arc;
use super::{
BLOCK_SIZE, SPARSE_LINALG_MODULE, get_kernel_function, get_or_load_module, grid_size,
launch_config, launch_error,
};
use crate::error::Result;
/// Launches the `find_diag_indices` kernel from `SPARSE_LINALG_MODULE`.
///
/// Presumably operates on a CSR matrix (given the `row_ptrs` / `col_indices` names) and
/// writes per-row diagonal positions into `diag_indices` — confirm against the kernel source.
///
/// The kernel is enqueued on `stream` over a 1-D grid of `grid_size(n)` blocks with
/// `BLOCK_SIZE` threads each. This function only enqueues the launch; it does not
/// synchronize `stream`.
///
/// If `n <= 0` there is nothing to launch: the function returns `Ok(())` without loading the
/// module or touching any of the pointers.
///
/// # Arguments
/// * `context`, `device_index` - passed to `get_or_load_module` to obtain the module.
/// * `stream` - stream the kernel is enqueued on.
/// * `row_ptrs`, `col_indices` - raw device addresses of the matrix structure.
/// * `diag_indices` - raw device address of the output buffer.
/// * `n` - size passed to the kernel and used to size the grid (presumably the row count).
///
/// # Errors
/// Returns an error if the module or the `find_diag_indices` function cannot be obtained, or
/// if the launch fails (wrapped via `launch_error`).
///
/// # Safety
/// Every pointer argument must be a valid device address usable from `context`, large enough
/// for everything the kernel reads or writes for the given `n` (for CSR that is presumably
/// `n + 1` row pointers and `n` output slots — TODO confirm). Each must stay allocated until
/// the kernel has finished on `stream`. The output buffer must not be accessed concurrently by
/// other work while the kernel runs.
pub unsafe fn launch_find_diag_indices(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    row_ptrs: u64,
    col_indices: u64,
    diag_indices: u64,
    n: i32,
) -> Result<()> {
    // Empty (or invalid, negative) size: nothing to do. Returning early also keeps
    // `n as u32` from wrapping a negative `n` into a huge grid. It also avoids requesting a
    // zero-block grid, which CUDA rejects as an invalid launch configuration.
    if n <= 0 {
        return Ok(());
    }
    let module = get_or_load_module(context, device_index, SPARSE_LINALG_MODULE)?;
    let func = get_kernel_function(&module, "find_diag_indices")?;
    // Lossless cast: `n > 0` was checked above.
    let cfg = launch_config((grid_size(n as u32), 1, 1), (BLOCK_SIZE, 1, 1), 0);
    let mut builder = stream.launch_builder(&func);
    builder.arg(&row_ptrs);
    builder.arg(&col_indices);
    builder.arg(&diag_indices);
    builder.arg(&n);
    // SAFETY: the caller upholds this function's `# Safety` contract (valid, live device
    // buffers sized for `n`). Arguments are pushed in the order the kernel is assumed to
    // declare them — NOTE(review): keep in sync with the kernel signature.
    unsafe { builder.launch(cfg) }.map_err(|e| launch_error("find_diag_indices", e))?;
    Ok(())
}
/// Launches the `find_diag_indices_csc` kernel from `SPARSE_LINALG_MODULE`.
///
/// Presumably the CSC counterpart of `find_diag_indices` (given the `col_ptrs` /
/// `row_indices` names), writing diagonal positions into `diag_ptr` — confirm against the
/// kernel source.
///
/// The kernel is enqueued on `stream` over a 1-D grid of `grid_size(n)` blocks with
/// `BLOCK_SIZE` threads each. This function only enqueues the launch; it does not
/// synchronize `stream`.
///
/// If `n <= 0` there is nothing to launch: the function returns `Ok(())` without loading the
/// module or touching any of the pointers.
///
/// # Arguments
/// * `context`, `device_index` - passed to `get_or_load_module` to obtain the module.
/// * `stream` - stream the kernel is enqueued on.
/// * `col_ptrs`, `row_indices` - raw device addresses of the matrix structure.
/// * `diag_ptr` - raw device address of the output buffer.
/// * `n` - size passed to the kernel and used to size the grid (presumably the column count).
///
/// # Errors
/// Returns an error if the module or the `find_diag_indices_csc` function cannot be
/// obtained, or if the launch fails (wrapped via `launch_error`).
///
/// # Safety
/// Every pointer argument must be a valid device address usable from `context`, large enough
/// for everything the kernel reads or writes for the given `n` (presumably `n + 1` column
/// pointers and `n` output slots — TODO confirm). Each must stay allocated until the kernel
/// has finished on `stream`. The output buffer must not be accessed concurrently while the
/// kernel runs.
pub unsafe fn launch_find_diag_indices_csc(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    col_ptrs: u64,
    row_indices: u64,
    diag_ptr: u64,
    n: i32,
) -> Result<()> {
    // Empty (or invalid, negative) size: nothing to do. Returning early also keeps
    // `n as u32` from wrapping a negative `n` into a huge grid. It also avoids requesting a
    // zero-block grid, which CUDA rejects as an invalid launch configuration.
    if n <= 0 {
        return Ok(());
    }
    let module = get_or_load_module(context, device_index, SPARSE_LINALG_MODULE)?;
    let func = get_kernel_function(&module, "find_diag_indices_csc")?;
    // Lossless cast: `n > 0` was checked above.
    let cfg = launch_config((grid_size(n as u32), 1, 1), (BLOCK_SIZE, 1, 1), 0);
    let mut builder = stream.launch_builder(&func);
    builder.arg(&col_ptrs);
    builder.arg(&row_indices);
    builder.arg(&diag_ptr);
    builder.arg(&n);
    // SAFETY: the caller upholds this function's `# Safety` contract (valid, live device
    // buffers sized for `n`). Arguments are pushed in the order the kernel is assumed to
    // declare them — NOTE(review): keep in sync with the kernel signature.
    unsafe { builder.launch(cfg) }.map_err(|e| launch_error("find_diag_indices_csc", e))?;
    Ok(())
}
/// Launches the `copy_f32` kernel from `SPARSE_LINALG_MODULE`.
///
/// Presumably copies `n` `f32` elements from `src` to `dst` on the device (inferred from the
/// kernel name — confirm against the kernel source).
///
/// The kernel is enqueued on `stream` over a 1-D grid of `grid_size(n)` blocks with
/// `BLOCK_SIZE` threads each. This function only enqueues the launch; it does not
/// synchronize `stream`.
///
/// If `n <= 0` there is nothing to copy: the function returns `Ok(())` without loading the
/// module or touching either pointer.
///
/// # Errors
/// Returns an error if the module or the `copy_f32` function cannot be obtained, or if the
/// launch fails (wrapped via `launch_error`).
///
/// # Safety
/// `src` and `dst` must be valid device addresses usable from `context`, each large enough
/// for `n` elements. Both must stay allocated until the kernel has finished on `stream`.
/// `dst` must not be accessed concurrently while the kernel runs.
// NOTE(review): no caller is visible in this file, hence `allow(dead_code)`; confirm whether
// this launcher is still needed.
#[allow(dead_code)]
pub unsafe fn launch_copy_f32(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    src: u64,
    dst: u64,
    n: i32,
) -> Result<()> {
    // Empty (or invalid, negative) size: nothing to do. Returning early also keeps
    // `n as u32` from wrapping a negative `n` into a huge grid. It also avoids requesting a
    // zero-block grid, which CUDA rejects as an invalid launch configuration.
    if n <= 0 {
        return Ok(());
    }
    let module = get_or_load_module(context, device_index, SPARSE_LINALG_MODULE)?;
    let func = get_kernel_function(&module, "copy_f32")?;
    // Lossless cast: `n > 0` was checked above.
    let cfg = launch_config((grid_size(n as u32), 1, 1), (BLOCK_SIZE, 1, 1), 0);
    let mut builder = stream.launch_builder(&func);
    builder.arg(&src);
    builder.arg(&dst);
    builder.arg(&n);
    // SAFETY: the caller upholds this function's `# Safety` contract (valid, live device
    // buffers of at least `n` elements). Arguments are pushed in the order the kernel is
    // assumed to declare them — NOTE(review): keep in sync with the kernel signature.
    unsafe { builder.launch(cfg) }.map_err(|e| launch_error("copy_f32", e))?;
    Ok(())
}
/// Launches the `copy_f64` kernel from `SPARSE_LINALG_MODULE`.
///
/// Presumably copies `n` `f64` elements from `src` to `dst` on the device (inferred from the
/// kernel name — confirm against the kernel source).
///
/// The kernel is enqueued on `stream` over a 1-D grid of `grid_size(n)` blocks with
/// `BLOCK_SIZE` threads each. This function only enqueues the launch; it does not
/// synchronize `stream`.
///
/// If `n <= 0` there is nothing to copy: the function returns `Ok(())` without loading the
/// module or touching either pointer.
///
/// # Errors
/// Returns an error if the module or the `copy_f64` function cannot be obtained, or if the
/// launch fails (wrapped via `launch_error`).
///
/// # Safety
/// `src` and `dst` must be valid device addresses usable from `context`, each large enough
/// for `n` elements. Both must stay allocated until the kernel has finished on `stream`.
/// `dst` must not be accessed concurrently while the kernel runs.
// NOTE(review): no caller is visible in this file, hence `allow(dead_code)`; confirm whether
// this launcher is still needed.
#[allow(dead_code)]
pub unsafe fn launch_copy_f64(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    src: u64,
    dst: u64,
    n: i32,
) -> Result<()> {
    // Empty (or invalid, negative) size: nothing to do. Returning early also keeps
    // `n as u32` from wrapping a negative `n` into a huge grid. It also avoids requesting a
    // zero-block grid, which CUDA rejects as an invalid launch configuration.
    if n <= 0 {
        return Ok(());
    }
    let module = get_or_load_module(context, device_index, SPARSE_LINALG_MODULE)?;
    let func = get_kernel_function(&module, "copy_f64")?;
    // Lossless cast: `n > 0` was checked above.
    let cfg = launch_config((grid_size(n as u32), 1, 1), (BLOCK_SIZE, 1, 1), 0);
    let mut builder = stream.launch_builder(&func);
    builder.arg(&src);
    builder.arg(&dst);
    builder.arg(&n);
    // SAFETY: the caller upholds this function's `# Safety` contract (valid, live device
    // buffers of at least `n` elements). Arguments are pushed in the order the kernel is
    // assumed to declare them — NOTE(review): keep in sync with the kernel signature.
    unsafe { builder.launch(cfg) }.map_err(|e| launch_error("copy_f64", e))?;
    Ok(())
}
/// Launches the `split_lu_scatter_f32` kernel from `SPARSE_LINALG_MODULE`.
///
/// Presumably scatters each of the `nnz` entries of `src_values` into either `l_values` or
/// `u_values`, using the index maps `l_map` / `u_map` — confirm against the kernel source.
///
/// The kernel is enqueued on `stream` over a 1-D grid of `grid_size(nnz)` blocks with
/// `BLOCK_SIZE` threads each. This function only enqueues the launch; it does not
/// synchronize `stream`.
///
/// If `nnz <= 0` there is nothing to scatter: the function returns `Ok(())` without loading
/// the module or touching any of the pointers.
///
/// # Errors
/// Returns an error if the module or the `split_lu_scatter_f32` function cannot be obtained,
/// or if the launch fails (wrapped via `launch_error`).
///
/// # Safety
/// Every pointer argument must be a valid device address usable from `context`, large enough
/// for everything the kernel reads or writes for the given `nnz`. That includes every index
/// stored in `l_map` / `u_map` being in bounds for `l_values` / `u_values` (assumed — TODO
/// confirm the map encoding). Each buffer must stay allocated until the kernel has finished
/// on `stream`. The output buffers must not be accessed concurrently while the kernel runs.
pub unsafe fn launch_split_lu_scatter_f32(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    src_values: u64,
    l_values: u64,
    u_values: u64,
    l_map: u64,
    u_map: u64,
    nnz: i32,
) -> Result<()> {
    // Empty (or invalid, negative) size: nothing to do. Returning early also keeps
    // `nnz as u32` from wrapping a negative value into a huge grid. It also avoids requesting
    // a zero-block grid, which CUDA rejects as an invalid launch configuration.
    if nnz <= 0 {
        return Ok(());
    }
    let module = get_or_load_module(context, device_index, SPARSE_LINALG_MODULE)?;
    let func = get_kernel_function(&module, "split_lu_scatter_f32")?;
    // Lossless cast: `nnz > 0` was checked above.
    let cfg = launch_config((grid_size(nnz as u32), 1, 1), (BLOCK_SIZE, 1, 1), 0);
    let mut builder = stream.launch_builder(&func);
    builder.arg(&src_values);
    builder.arg(&l_values);
    builder.arg(&u_values);
    builder.arg(&l_map);
    builder.arg(&u_map);
    builder.arg(&nnz);
    // SAFETY: the caller upholds this function's `# Safety` contract (valid, live device
    // buffers sized for `nnz`). Arguments are pushed in the order the kernel is assumed to
    // declare them — NOTE(review): keep in sync with the kernel signature.
    unsafe { builder.launch(cfg) }.map_err(|e| launch_error("split_lu_scatter_f32", e))?;
    Ok(())
}
/// Launches the `split_lu_scatter_f64` kernel from `SPARSE_LINALG_MODULE`.
///
/// Presumably scatters each of the `nnz` entries of `src_values` into either `l_values` or
/// `u_values`, using the index maps `l_map` / `u_map` — confirm against the kernel source.
///
/// The kernel is enqueued on `stream` over a 1-D grid of `grid_size(nnz)` blocks with
/// `BLOCK_SIZE` threads each. This function only enqueues the launch; it does not
/// synchronize `stream`.
///
/// If `nnz <= 0` there is nothing to scatter: the function returns `Ok(())` without loading
/// the module or touching any of the pointers.
///
/// # Errors
/// Returns an error if the module or the `split_lu_scatter_f64` function cannot be obtained,
/// or if the launch fails (wrapped via `launch_error`).
///
/// # Safety
/// Every pointer argument must be a valid device address usable from `context`, large enough
/// for everything the kernel reads or writes for the given `nnz`. That includes every index
/// stored in `l_map` / `u_map` being in bounds for `l_values` / `u_values` (assumed — TODO
/// confirm the map encoding). Each buffer must stay allocated until the kernel has finished
/// on `stream`. The output buffers must not be accessed concurrently while the kernel runs.
pub unsafe fn launch_split_lu_scatter_f64(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    src_values: u64,
    l_values: u64,
    u_values: u64,
    l_map: u64,
    u_map: u64,
    nnz: i32,
) -> Result<()> {
    // Empty (or invalid, negative) size: nothing to do. Returning early also keeps
    // `nnz as u32` from wrapping a negative value into a huge grid. It also avoids requesting
    // a zero-block grid, which CUDA rejects as an invalid launch configuration.
    if nnz <= 0 {
        return Ok(());
    }
    let module = get_or_load_module(context, device_index, SPARSE_LINALG_MODULE)?;
    let func = get_kernel_function(&module, "split_lu_scatter_f64")?;
    // Lossless cast: `nnz > 0` was checked above.
    let cfg = launch_config((grid_size(nnz as u32), 1, 1), (BLOCK_SIZE, 1, 1), 0);
    let mut builder = stream.launch_builder(&func);
    builder.arg(&src_values);
    builder.arg(&l_values);
    builder.arg(&u_values);
    builder.arg(&l_map);
    builder.arg(&u_map);
    builder.arg(&nnz);
    // SAFETY: the caller upholds this function's `# Safety` contract (valid, live device
    // buffers sized for `nnz`). Arguments are pushed in the order the kernel is assumed to
    // declare them — NOTE(review): keep in sync with the kernel signature.
    unsafe { builder.launch(cfg) }.map_err(|e| launch_error("split_lu_scatter_f64", e))?;
    Ok(())
}
/// Launches the `extract_lower_scatter_f32` kernel from `SPARSE_LINALG_MODULE`.
///
/// Presumably scatters the lower-triangular subset of the `nnz` entries in `src_values` into
/// `dst_values` via the index map `lower_map` — confirm against the kernel source.
///
/// The kernel is enqueued on `stream` over a 1-D grid of `grid_size(nnz)` blocks with
/// `BLOCK_SIZE` threads each. This function only enqueues the launch; it does not
/// synchronize `stream`.
///
/// If `nnz <= 0` there is nothing to scatter: the function returns `Ok(())` without loading
/// the module or touching any of the pointers.
///
/// # Errors
/// Returns an error if the module or the `extract_lower_scatter_f32` function cannot be
/// obtained, or if the launch fails (wrapped via `launch_error`).
///
/// # Safety
/// Every pointer argument must be a valid device address usable from `context`, large enough
/// for everything the kernel reads or writes for the given `nnz`. That includes every index
/// stored in `lower_map` being in bounds for `dst_values` (assumed — TODO confirm the map
/// encoding). Each buffer must stay allocated until the kernel has finished on `stream`.
/// `dst_values` must not be accessed concurrently while the kernel runs.
pub unsafe fn launch_extract_lower_scatter_f32(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    src_values: u64,
    dst_values: u64,
    lower_map: u64,
    nnz: i32,
) -> Result<()> {
    // Empty (or invalid, negative) size: nothing to do. Returning early also keeps
    // `nnz as u32` from wrapping a negative value into a huge grid. It also avoids requesting
    // a zero-block grid, which CUDA rejects as an invalid launch configuration.
    if nnz <= 0 {
        return Ok(());
    }
    let module = get_or_load_module(context, device_index, SPARSE_LINALG_MODULE)?;
    let func = get_kernel_function(&module, "extract_lower_scatter_f32")?;
    // Lossless cast: `nnz > 0` was checked above.
    let cfg = launch_config((grid_size(nnz as u32), 1, 1), (BLOCK_SIZE, 1, 1), 0);
    let mut builder = stream.launch_builder(&func);
    builder.arg(&src_values);
    builder.arg(&dst_values);
    builder.arg(&lower_map);
    builder.arg(&nnz);
    // SAFETY: the caller upholds this function's `# Safety` contract (valid, live device
    // buffers sized for `nnz`). Arguments are pushed in the order the kernel is assumed to
    // declare them — NOTE(review): keep in sync with the kernel signature.
    unsafe { builder.launch(cfg) }.map_err(|e| launch_error("extract_lower_scatter_f32", e))?;
    Ok(())
}
/// Launches the `extract_lower_scatter_f64` kernel from `SPARSE_LINALG_MODULE`.
///
/// Presumably scatters the lower-triangular subset of the `nnz` entries in `src_values` into
/// `dst_values` via the index map `lower_map` — confirm against the kernel source.
///
/// The kernel is enqueued on `stream` over a 1-D grid of `grid_size(nnz)` blocks with
/// `BLOCK_SIZE` threads each. This function only enqueues the launch; it does not
/// synchronize `stream`.
///
/// If `nnz <= 0` there is nothing to scatter: the function returns `Ok(())` without loading
/// the module or touching any of the pointers.
///
/// # Errors
/// Returns an error if the module or the `extract_lower_scatter_f64` function cannot be
/// obtained, or if the launch fails (wrapped via `launch_error`).
///
/// # Safety
/// Every pointer argument must be a valid device address usable from `context`, large enough
/// for everything the kernel reads or writes for the given `nnz`. That includes every index
/// stored in `lower_map` being in bounds for `dst_values` (assumed — TODO confirm the map
/// encoding). Each buffer must stay allocated until the kernel has finished on `stream`.
/// `dst_values` must not be accessed concurrently while the kernel runs.
pub unsafe fn launch_extract_lower_scatter_f64(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    src_values: u64,
    dst_values: u64,
    lower_map: u64,
    nnz: i32,
) -> Result<()> {
    // Empty (or invalid, negative) size: nothing to do. Returning early also keeps
    // `nnz as u32` from wrapping a negative value into a huge grid. It also avoids requesting
    // a zero-block grid, which CUDA rejects as an invalid launch configuration.
    if nnz <= 0 {
        return Ok(());
    }
    let module = get_or_load_module(context, device_index, SPARSE_LINALG_MODULE)?;
    let func = get_kernel_function(&module, "extract_lower_scatter_f64")?;
    // Lossless cast: `nnz > 0` was checked above.
    let cfg = launch_config((grid_size(nnz as u32), 1, 1), (BLOCK_SIZE, 1, 1), 0);
    let mut builder = stream.launch_builder(&func);
    builder.arg(&src_values);
    builder.arg(&dst_values);
    builder.arg(&lower_map);
    builder.arg(&nnz);
    // SAFETY: the caller upholds this function's `# Safety` contract (valid, live device
    // buffers sized for `nnz`). Arguments are pushed in the order the kernel is assumed to
    // declare them — NOTE(review): keep in sync with the kernel signature.
    unsafe { builder.launch(cfg) }.map_err(|e| launch_error("extract_lower_scatter_f64", e))?;
    Ok(())
}