use std::mem::size_of;
use std::sync::Arc;
use cudarc::driver::CudaSlice;
use parking_lot::Mutex;
use crate::error::GpuError;
use crate::sys::cusolver::LIB;
/// Makes sure the shared solver workspace holds at least `needed_bytes` bytes.
///
/// The workspace mutex is held for the whole call. If the slot already
/// contains a buffer whose length is `>= needed_bytes` nothing happens.
/// Otherwise a new buffer of exactly `needed_bytes` bytes is requested from
/// `stream` via `alloc_zeros` and stored in the slot, replacing whatever was
/// there before (the previous contents are not copied over).
///
/// # Errors
///
/// Returns [`GpuError::OutOfMemory`] when the allocation on `stream` fails;
/// in that case the slot is left untouched.
pub fn ensure_workspace_bytes(
    workspace: &Mutex<Option<CudaSlice<u8>>>,
    stream: &Arc<cudarc::driver::CudaStream>,
    needed_bytes: usize,
) -> Result<(), GpuError> {
    let mut slot = workspace.lock();

    // An empty slot counts as a zero-sized workspace.
    let available = match slot.as_ref() {
        Some(buffer) => buffer.len(),
        None => 0,
    };

    if available < needed_bytes {
        // Allocate first, then swap in: a failed allocation must not discard
        // the buffer that is currently stored.
        let replacement = stream.alloc_zeros::<u8>(needed_bytes).map_err(|e| {
            GpuError::OutOfMemory(format!("solver workspace ({needed_bytes}B): {e}"))
        })?;
        *slot = Some(replacement);
    }

    Ok(())
}
/// Converts an element count `lwork` into a byte count for elements of type `T`.
///
/// Negative counts are treated as zero, and the multiplication by
/// `size_of::<T>()` saturates at `usize::MAX` instead of overflowing.
pub fn lwork_bytes<T>(lwork: i32) -> usize {
    let element_size = size_of::<T>();
    // Clamp negatives before widening so the cast can never wrap around.
    let element_count = if lwork > 0 { lwork as usize } else { 0 };
    element_count.saturating_mul(element_size)
}
/// Reads the device-side `info` status word back to the host and turns a
/// non-zero value into an error.
///
/// The `info` mutex is held for the whole call. The buffer is copied into a
/// one-element host vector with `memcpy_dtoh` on `stream`, the stream is then
/// synchronized, and the first host element is inspected.
///
/// `op` is a label that prefixes every error message produced here.
///
/// NOTE(review): the host buffer has length 1, so `info` is presumably a
/// single-element device buffer — confirm against the callers.
///
/// # Errors
///
/// Returns [`GpuError::LibraryError`] (tagged with `LIB`) when the copy fails,
/// when the synchronization fails, or when the status word is non-zero.
pub fn check_info(
    info: &Mutex<CudaSlice<i32>>,
    stream: &Arc<cudarc::driver::CudaStream>,
    op: &'static str,
) -> Result<(), GpuError> {
    // Every failure in this function is reported the same way; only the text differs.
    let library_error = |msg: String| GpuError::LibraryError { lib: LIB, msg };

    let device_info = info.lock();
    let mut status = vec![0i32; 1];

    stream
        .memcpy_dtoh(&*device_info, &mut status[..])
        .map_err(|e| library_error(format!("{op}: read info: {e}")))?;

    // The host value is only inspected after the stream has been synchronized.
    stream
        .synchronize()
        .map_err(|e| library_error(format!("{op}: sync after info: {e}")))?;

    match status[0] {
        0 => Ok(()),
        nonzero => Err(library_error(format!("{op}: info={}", nonzero))),
    }
}
/// Reads a device-side array of `info` status words back to the host and
/// reports the first non-zero entry as an error.
///
/// A host vector of `n` zeros is filled with `memcpy_dtoh` on `stream`, the
/// stream is then synchronized, and the entries are scanned in index order.
///
/// `op` is a label that prefixes every error message produced here; a failing
/// entry is reported together with its index as `batch[idx]`.
///
/// NOTE(review): `n` is presumably the batch size and expected to match the
/// length of `info` — confirm against the callers.
///
/// # Errors
///
/// Returns [`GpuError::LibraryError`] (tagged with `LIB`) when the copy fails,
/// when the synchronization fails, or when any copied entry is non-zero.
pub fn check_info_array(
    info: &CudaSlice<i32>,
    stream: &Arc<cudarc::driver::CudaStream>,
    op: &'static str,
    n: usize,
) -> Result<(), GpuError> {
    let mut statuses = vec![0i32; n];

    if let Err(e) = stream.memcpy_dtoh(info, &mut statuses[..]) {
        return Err(GpuError::LibraryError {
            lib: LIB,
            msg: format!("{op}: read info array: {e}"),
        });
    }

    // The host values are only inspected after the stream has been synchronized.
    if let Err(e) = stream.synchronize() {
        return Err(GpuError::LibraryError {
            lib: LIB,
            msg: format!("{op}: sync after info array: {e}"),
        });
    }

    // Only the lowest-index failure is reported.
    match statuses.iter().position(|&status| status != 0) {
        None => Ok(()),
        Some(idx) => {
            let code = statuses[idx];
            Err(GpuError::LibraryError {
                lib: LIB,
                msg: format!("{op}: batch[{idx}] info={code}"),
            })
        }
    }
}