use cudarc::driver::PushKernelArg;
use cudarc::driver::safe::{CudaContext, CudaStream};
use std::sync::Arc;
use super::{
BLOCK_SIZE, SPARSE_LINALG_MODULE, get_kernel_function, get_or_load_module, grid_size,
launch_config, launch_error,
};
use crate::error::Result;
/// Enqueues the `sparse_qr_apply_reflector_f32` kernel on `stream`.
///
/// Judging by the name, the kernel applies a Householder reflector described by
/// `v` (starting at `v_start`, `v_len` entries) and the scalar behind `tau_ptr` to the
/// buffer `work` of length `m` — NOTE(review): confirm against the kernel source.
/// The grid is a single block of `BLOCK_SIZE` threads. This function does not
/// synchronize `stream`.
///
/// # Errors
///
/// Fails if the sparse linear-algebra module cannot be loaded, the kernel symbol
/// cannot be found, or the driver rejects the launch.
///
/// # Safety
///
/// `v`, `tau_ptr` and `work` are raw device addresses handed to the kernel unchecked.
/// The caller must guarantee they refer to live allocations (presumably `f32`) on the
/// device `stream` belongs to. The allocations must be large enough for every index
/// the kernel touches, and must not be freed or concurrently mutated until the
/// kernel completes.
pub unsafe fn launch_sparse_qr_apply_reflector_f32(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    v: u64,
    v_start: i32,
    v_len: i32,
    tau_ptr: u64,
    work: u64,
    m: i32,
) -> Result<()> {
    const KERNEL: &str = "sparse_qr_apply_reflector_f32";
    let sparse_module = get_or_load_module(context, device_index, SPARSE_LINALG_MODULE)?;
    let kernel = get_kernel_function(&sparse_module, KERNEL)?;
    let mut args = stream.launch_builder(&kernel);
    args.arg(&v)
        .arg(&v_start)
        .arg(&v_len)
        .arg(&tau_ptr)
        .arg(&work)
        .arg(&m);
    // A grid of exactly one block holding `BLOCK_SIZE` threads.
    let single_block = launch_config((1, 1, 1), (BLOCK_SIZE, 1, 1), 0);
    // SAFETY: arguments are pushed in the kernel's parameter order with matching
    // scalar widths; validity of the device addresses is the caller's contract
    // documented under `# Safety`.
    unsafe { args.launch(single_block) }.map_err(|e| launch_error(KERNEL, e))?;
    Ok(())
}
/// Enqueues the `sparse_qr_apply_reflector_f64` kernel on `stream`.
///
/// Judging by the name, the kernel applies a Householder reflector described by
/// `v` (starting at `v_start`, `v_len` entries) and the scalar behind `tau_ptr` to the
/// buffer `work` of length `m` — NOTE(review): confirm against the kernel source.
/// The grid is a single block of `BLOCK_SIZE` threads. This function does not
/// synchronize `stream`.
///
/// # Errors
///
/// Fails if the sparse linear-algebra module cannot be loaded, the kernel symbol
/// cannot be found, or the driver rejects the launch.
///
/// # Safety
///
/// `v`, `tau_ptr` and `work` are raw device addresses handed to the kernel unchecked.
/// The caller must guarantee they refer to live allocations (presumably `f64`) on the
/// device `stream` belongs to. The allocations must be large enough for every index
/// the kernel touches, and must not be freed or concurrently mutated until the
/// kernel completes.
pub unsafe fn launch_sparse_qr_apply_reflector_f64(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    v: u64,
    v_start: i32,
    v_len: i32,
    tau_ptr: u64,
    work: u64,
    m: i32,
) -> Result<()> {
    const KERNEL: &str = "sparse_qr_apply_reflector_f64";
    let sparse_module = get_or_load_module(context, device_index, SPARSE_LINALG_MODULE)?;
    let kernel = get_kernel_function(&sparse_module, KERNEL)?;
    let mut args = stream.launch_builder(&kernel);
    args.arg(&v)
        .arg(&v_start)
        .arg(&v_len)
        .arg(&tau_ptr)
        .arg(&work)
        .arg(&m);
    // A grid of exactly one block holding `BLOCK_SIZE` threads.
    let single_block = launch_config((1, 1, 1), (BLOCK_SIZE, 1, 1), 0);
    // SAFETY: arguments are pushed in the kernel's parameter order with matching
    // scalar widths; validity of the device addresses is the caller's contract
    // documented under `# Safety`.
    unsafe { args.launch(single_block) }.map_err(|e| launch_error(KERNEL, e))?;
    Ok(())
}
/// Enqueues the `sparse_qr_norm_f32` kernel on `stream`.
///
/// Judging by the name, the kernel reduces `count` entries of `work` beginning at
/// `start` to a norm (possibly squared) and stores it at `result`. NOTE(review):
/// confirm against the kernel source. The grid is a single block of `BLOCK_SIZE`
/// threads. This function does not synchronize `stream`.
///
/// # Errors
///
/// Fails if the sparse linear-algebra module cannot be loaded, the kernel symbol
/// cannot be found, or the driver rejects the launch.
///
/// # Safety
///
/// `work` and `result` are raw device addresses handed to the kernel unchecked.
/// The caller must guarantee they refer to live allocations (presumably `f32`) on the
/// device `stream` belongs to. The allocations must be large enough for the indices
/// the kernel accesses, and must stay valid until the kernel completes.
pub unsafe fn launch_sparse_qr_norm_f32(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    work: u64,
    start: i32,
    count: i32,
    result: u64,
) -> Result<()> {
    const KERNEL: &str = "sparse_qr_norm_f32";
    let sparse_module = get_or_load_module(context, device_index, SPARSE_LINALG_MODULE)?;
    let kernel = get_kernel_function(&sparse_module, KERNEL)?;
    let mut args = stream.launch_builder(&kernel);
    args.arg(&work).arg(&start).arg(&count).arg(&result);
    // A grid of exactly one block holding `BLOCK_SIZE` threads.
    let single_block = launch_config((1, 1, 1), (BLOCK_SIZE, 1, 1), 0);
    // SAFETY: arguments are pushed in the kernel's parameter order with matching
    // scalar widths; validity of the device addresses is the caller's contract
    // documented under `# Safety`.
    unsafe { args.launch(single_block) }.map_err(|e| launch_error(KERNEL, e))?;
    Ok(())
}
/// Enqueues the `sparse_qr_norm_f64` kernel on `stream`.
///
/// Judging by the name, the kernel reduces `count` entries of `work` beginning at
/// `start` to a norm (possibly squared) and stores it at `result`. NOTE(review):
/// confirm against the kernel source. The grid is a single block of `BLOCK_SIZE`
/// threads. This function does not synchronize `stream`.
///
/// # Errors
///
/// Fails if the sparse linear-algebra module cannot be loaded, the kernel symbol
/// cannot be found, or the driver rejects the launch.
///
/// # Safety
///
/// `work` and `result` are raw device addresses handed to the kernel unchecked.
/// The caller must guarantee they refer to live allocations (presumably `f64`) on the
/// device `stream` belongs to. The allocations must be large enough for the indices
/// the kernel accesses, and must stay valid until the kernel completes.
pub unsafe fn launch_sparse_qr_norm_f64(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    work: u64,
    start: i32,
    count: i32,
    result: u64,
) -> Result<()> {
    const KERNEL: &str = "sparse_qr_norm_f64";
    let sparse_module = get_or_load_module(context, device_index, SPARSE_LINALG_MODULE)?;
    let kernel = get_kernel_function(&sparse_module, KERNEL)?;
    let mut args = stream.launch_builder(&kernel);
    args.arg(&work).arg(&start).arg(&count).arg(&result);
    // A grid of exactly one block holding `BLOCK_SIZE` threads.
    let single_block = launch_config((1, 1, 1), (BLOCK_SIZE, 1, 1), 0);
    // SAFETY: arguments are pushed in the kernel's parameter order with matching
    // scalar widths; validity of the device addresses is the caller's contract
    // documented under `# Safety`.
    unsafe { args.launch(single_block) }.map_err(|e| launch_error(KERNEL, e))?;
    Ok(())
}
/// Enqueues the `sparse_qr_householder_f32` kernel on `stream`.
///
/// Judging by the name, the kernel forms a Householder reflector from `work`
/// (from `start` up to `m`) using the value behind `norm_sq_ptr`, and writes the
/// vector, scalar and diagonal entry to `out_v`, `out_tau` and `out_diag`.
/// NOTE(review): confirm against the kernel source. The grid is a single block of
/// `BLOCK_SIZE` threads. This function does not synchronize `stream`.
///
/// # Errors
///
/// Fails if the sparse linear-algebra module cannot be loaded, the kernel symbol
/// cannot be found, or the driver rejects the launch.
///
/// # Safety
///
/// `work`, `norm_sq_ptr`, `out_v`, `out_tau` and `out_diag` are raw device addresses
/// handed to the kernel unchecked. The caller must guarantee they refer to live
/// allocations (presumably `f32`) on the device `stream` belongs to. The allocations
/// must be large enough for every index the kernel touches, and must not be freed or
/// concurrently mutated until the kernel completes.
pub unsafe fn launch_sparse_qr_householder_f32(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    work: u64,
    start: i32,
    m: i32,
    norm_sq_ptr: u64,
    out_v: u64,
    out_tau: u64,
    out_diag: u64,
) -> Result<()> {
    const KERNEL: &str = "sparse_qr_householder_f32";
    let sparse_module = get_or_load_module(context, device_index, SPARSE_LINALG_MODULE)?;
    let kernel = get_kernel_function(&sparse_module, KERNEL)?;
    let mut args = stream.launch_builder(&kernel);
    args.arg(&work)
        .arg(&start)
        .arg(&m)
        .arg(&norm_sq_ptr)
        .arg(&out_v)
        .arg(&out_tau)
        .arg(&out_diag);
    // A grid of exactly one block holding `BLOCK_SIZE` threads.
    let single_block = launch_config((1, 1, 1), (BLOCK_SIZE, 1, 1), 0);
    // SAFETY: arguments are pushed in the kernel's parameter order with matching
    // scalar widths; validity of the device addresses is the caller's contract
    // documented under `# Safety`.
    unsafe { args.launch(single_block) }.map_err(|e| launch_error(KERNEL, e))?;
    Ok(())
}
/// Enqueues the `sparse_qr_householder_f64` kernel on `stream`.
///
/// Judging by the name, the kernel forms a Householder reflector from `work`
/// (from `start` up to `m`) using the value behind `norm_sq_ptr`, and writes the
/// vector, scalar and diagonal entry to `out_v`, `out_tau` and `out_diag`.
/// NOTE(review): confirm against the kernel source. The grid is a single block of
/// `BLOCK_SIZE` threads. This function does not synchronize `stream`.
///
/// # Errors
///
/// Fails if the sparse linear-algebra module cannot be loaded, the kernel symbol
/// cannot be found, or the driver rejects the launch.
///
/// # Safety
///
/// `work`, `norm_sq_ptr`, `out_v`, `out_tau` and `out_diag` are raw device addresses
/// handed to the kernel unchecked. The caller must guarantee they refer to live
/// allocations (presumably `f64`) on the device `stream` belongs to. The allocations
/// must be large enough for every index the kernel touches, and must not be freed or
/// concurrently mutated until the kernel completes.
pub unsafe fn launch_sparse_qr_householder_f64(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    work: u64,
    start: i32,
    m: i32,
    norm_sq_ptr: u64,
    out_v: u64,
    out_tau: u64,
    out_diag: u64,
) -> Result<()> {
    const KERNEL: &str = "sparse_qr_householder_f64";
    let sparse_module = get_or_load_module(context, device_index, SPARSE_LINALG_MODULE)?;
    let kernel = get_kernel_function(&sparse_module, KERNEL)?;
    let mut args = stream.launch_builder(&kernel);
    args.arg(&work)
        .arg(&start)
        .arg(&m)
        .arg(&norm_sq_ptr)
        .arg(&out_v)
        .arg(&out_tau)
        .arg(&out_diag);
    // A grid of exactly one block holding `BLOCK_SIZE` threads.
    let single_block = launch_config((1, 1, 1), (BLOCK_SIZE, 1, 1), 0);
    // SAFETY: arguments are pushed in the kernel's parameter order with matching
    // scalar widths; validity of the device addresses is the caller's contract
    // documented under `# Safety`.
    unsafe { args.launch(single_block) }.map_err(|e| launch_error(KERNEL, e))?;
    Ok(())
}
/// Enqueues the `sparse_qr_extract_r_f32` kernel over `count` elements.
///
/// Judging by the name, the kernel copies R-factor entries from `work` into
/// `output`. NOTE(review): confirm against the kernel source. The grid is sized by
/// `grid_size(count)` with `BLOCK_SIZE` threads per block. This function does not
/// synchronize `stream`.
///
/// A non-positive `count` means there is nothing to extract. In that case the
/// function returns `Ok(())` without loading the module or launching anything.
///
/// # Errors
///
/// Fails if the sparse linear-algebra module cannot be loaded, the kernel symbol
/// cannot be found, or the driver rejects the launch.
///
/// # Safety
///
/// `work` and `output` are raw device addresses handed to the kernel unchecked.
/// The caller must guarantee they refer to live allocations (presumably `f32`) on the
/// device `stream` belongs to. They must be large enough for the indices the kernel
/// touches (presumably at least `count` elements — confirm), and must stay valid
/// until the kernel completes.
pub unsafe fn launch_sparse_qr_extract_r_f32(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    work: u64,
    count: i32,
    output: u64,
) -> Result<()> {
    // Guard before converting: `count as u32` would wrap a negative count into a
    // multi-billion-element grid, and a zero count may yield an empty grid, which
    // the driver rejects.
    let elements = match u32::try_from(count) {
        Ok(n) if n > 0 => n,
        _ => return Ok(()),
    };
    let module = get_or_load_module(context, device_index, SPARSE_LINALG_MODULE)?;
    let func = get_kernel_function(&module, "sparse_qr_extract_r_f32")?;
    let cfg = launch_config((grid_size(elements), 1, 1), (BLOCK_SIZE, 1, 1), 0);
    let mut builder = stream.launch_builder(&func);
    builder.arg(&work);
    builder.arg(&count);
    builder.arg(&output);
    // SAFETY: arguments are pushed in the kernel's parameter order with matching
    // scalar widths; validity of the device addresses is the caller's contract
    // documented under `# Safety`.
    unsafe { builder.launch(cfg) }.map_err(|e| launch_error("sparse_qr_extract_r_f32", e))?;
    Ok(())
}
/// Enqueues the `sparse_qr_extract_r_f64` kernel over `count` elements.
///
/// Judging by the name, the kernel copies R-factor entries from `work` into
/// `output`. NOTE(review): confirm against the kernel source. The grid is sized by
/// `grid_size(count)` with `BLOCK_SIZE` threads per block. This function does not
/// synchronize `stream`.
///
/// A non-positive `count` means there is nothing to extract. In that case the
/// function returns `Ok(())` without loading the module or launching anything.
///
/// # Errors
///
/// Fails if the sparse linear-algebra module cannot be loaded, the kernel symbol
/// cannot be found, or the driver rejects the launch.
///
/// # Safety
///
/// `work` and `output` are raw device addresses handed to the kernel unchecked.
/// The caller must guarantee they refer to live allocations (presumably `f64`) on the
/// device `stream` belongs to. They must be large enough for the indices the kernel
/// touches (presumably at least `count` elements — confirm), and must stay valid
/// until the kernel completes.
pub unsafe fn launch_sparse_qr_extract_r_f64(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    work: u64,
    count: i32,
    output: u64,
) -> Result<()> {
    // Guard before converting: `count as u32` would wrap a negative count into a
    // multi-billion-element grid, and a zero count may yield an empty grid, which
    // the driver rejects.
    let elements = match u32::try_from(count) {
        Ok(n) if n > 0 => n,
        _ => return Ok(()),
    };
    let module = get_or_load_module(context, device_index, SPARSE_LINALG_MODULE)?;
    let func = get_kernel_function(&module, "sparse_qr_extract_r_f64")?;
    let cfg = launch_config((grid_size(elements), 1, 1), (BLOCK_SIZE, 1, 1), 0);
    let mut builder = stream.launch_builder(&func);
    builder.arg(&work);
    builder.arg(&count);
    builder.arg(&output);
    // SAFETY: arguments are pushed in the kernel's parameter order with matching
    // scalar widths; validity of the device addresses is the caller's contract
    // documented under `# Safety`.
    unsafe { builder.launch(cfg) }.map_err(|e| launch_error("sparse_qr_extract_r_f64", e))?;
    Ok(())
}
/// Enqueues the `sparse_qr_clear_f32` kernel over `n` elements of `work`.
///
/// Judging by the name, the kernel resets (presumably zeroes) the first `n`
/// entries of `work`. NOTE(review): confirm against the kernel source. The grid is
/// sized by `grid_size(n)` with `BLOCK_SIZE` threads per block. This function does
/// not synchronize `stream`.
///
/// A non-positive `n` means there is nothing to clear. In that case the function
/// returns `Ok(())` without loading the module or launching anything.
///
/// # Errors
///
/// Fails if the sparse linear-algebra module cannot be loaded, the kernel symbol
/// cannot be found, or the driver rejects the launch.
///
/// # Safety
///
/// `work` is a raw device address handed to the kernel unchecked. The caller must
/// guarantee it refers to a live allocation (presumably `f32`) on the device `stream`
/// belongs to. It must be large enough for the indices the kernel writes (presumably
/// at least `n` elements — confirm), and must stay valid until the kernel completes.
pub unsafe fn launch_sparse_qr_clear_f32(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    work: u64,
    n: i32,
) -> Result<()> {
    // Guard before converting: `n as u32` would wrap a negative length into a
    // multi-billion-element grid, and a zero length may yield an empty grid, which
    // the driver rejects.
    let elements = match u32::try_from(n) {
        Ok(len) if len > 0 => len,
        _ => return Ok(()),
    };
    let module = get_or_load_module(context, device_index, SPARSE_LINALG_MODULE)?;
    let func = get_kernel_function(&module, "sparse_qr_clear_f32")?;
    let cfg = launch_config((grid_size(elements), 1, 1), (BLOCK_SIZE, 1, 1), 0);
    let mut builder = stream.launch_builder(&func);
    builder.arg(&work);
    builder.arg(&n);
    // SAFETY: arguments are pushed in the kernel's parameter order with matching
    // scalar widths; validity of `work` is the caller's contract documented under
    // `# Safety`.
    unsafe { builder.launch(cfg) }.map_err(|e| launch_error("sparse_qr_clear_f32", e))?;
    Ok(())
}
/// Enqueues the `sparse_qr_clear_f64` kernel over `n` elements of `work`.
///
/// Judging by the name, the kernel resets (presumably zeroes) the first `n`
/// entries of `work`. NOTE(review): confirm against the kernel source. The grid is
/// sized by `grid_size(n)` with `BLOCK_SIZE` threads per block. This function does
/// not synchronize `stream`.
///
/// A non-positive `n` means there is nothing to clear. In that case the function
/// returns `Ok(())` without loading the module or launching anything.
///
/// # Errors
///
/// Fails if the sparse linear-algebra module cannot be loaded, the kernel symbol
/// cannot be found, or the driver rejects the launch.
///
/// # Safety
///
/// `work` is a raw device address handed to the kernel unchecked. The caller must
/// guarantee it refers to a live allocation (presumably `f64`) on the device `stream`
/// belongs to. It must be large enough for the indices the kernel writes (presumably
/// at least `n` elements — confirm), and must stay valid until the kernel completes.
pub unsafe fn launch_sparse_qr_clear_f64(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    work: u64,
    n: i32,
) -> Result<()> {
    // Guard before converting: `n as u32` would wrap a negative length into a
    // multi-billion-element grid, and a zero length may yield an empty grid, which
    // the driver rejects.
    let elements = match u32::try_from(n) {
        Ok(len) if len > 0 => len,
        _ => return Ok(()),
    };
    let module = get_or_load_module(context, device_index, SPARSE_LINALG_MODULE)?;
    let func = get_kernel_function(&module, "sparse_qr_clear_f64")?;
    let cfg = launch_config((grid_size(elements), 1, 1), (BLOCK_SIZE, 1, 1), 0);
    let mut builder = stream.launch_builder(&func);
    builder.arg(&work);
    builder.arg(&n);
    // SAFETY: arguments are pushed in the kernel's parameter order with matching
    // scalar widths; validity of `work` is the caller's contract documented under
    // `# Safety`.
    unsafe { builder.launch(cfg) }.map_err(|e| launch_error("sparse_qr_clear_f64", e))?;
    Ok(())
}