use std::sync::Arc;
use oxicuda_blas::BlasHandle;
use oxicuda_driver::{Context, Stream};
use oxicuda_memory::DeviceBuffer;
use oxicuda_ptx::arch::SmVersion;
use oxicuda_ptx::cache::PtxCache;
use crate::error::{SolverError, SolverResult};
/// Bundle of GPU-side resources that solver routines operate with.
///
/// A handle is built by [`SolverHandle::new`] from a shared [`Context`] and
/// owns a stream, a BLAS handle, a PTX cache, the SM version reported by the
/// BLAS handle, and a byte-typed device workspace buffer.
///
/// All fields are private; they are exposed through the accessor methods on
/// the `impl` block.
///
/// NOTE(review): Rust drops struct fields in declaration order, so reordering
/// the fields below changes the order in which these resources are released.
/// Confirm teardown requirements before rearranging them.
pub struct SolverHandle {
/// Shared reference to the context this handle was created from.
context: Arc<Context>,
/// Stream created from `context` when the handle is constructed.
stream: Stream,
/// BLAS handle created from `context`; also the source of `sm_version`.
blas_handle: BlasHandle,
/// PTX cache owned by this handle.
ptx_cache: PtxCache,
/// SM version read from `blas_handle` at construction time.
sm_version: SmVersion,
/// Device buffer of `u8` used as a workspace; starts at 4096 elements and is
/// replaced by a larger allocation on demand via `ensure_workspace`.
workspace: DeviceBuffer<u8>,
}
impl SolverHandle {
    /// Builds a solver handle on top of the given context.
    ///
    /// Resources are created in this order: a [`BlasHandle`] for `ctx`, the
    /// SM version as reported by that BLAS handle, a new [`Stream`] on `ctx`,
    /// a [`PtxCache`], and finally a 4096-element `u8` device workspace
    /// obtained from `DeviceBuffer::zeroed`. The handle keeps its own
    /// `Arc` clone of `ctx`.
    ///
    /// # Errors
    ///
    /// Propagates (via `?`) any failure from creating the BLAS handle, the
    /// stream, or the workspace buffer. A PTX cache creation failure is
    /// reported as [`SolverError::InternalError`] carrying the underlying
    /// error's message.
    pub fn new(ctx: &Arc<Context>) -> SolverResult<Self> {
        // Number of workspace bytes allocated up front.
        const INITIAL_WORKSPACE_BYTES: usize = 4096;

        let blas = BlasHandle::new(ctx)?;
        let arch = blas.sm_version();
        let solver_stream = Stream::new(ctx)?;
        let cache = PtxCache::new().map_err(|e| {
            SolverError::InternalError(format!("failed to create PTX cache: {e}"))
        })?;
        let scratch = DeviceBuffer::<u8>::zeroed(INITIAL_WORKSPACE_BYTES)?;

        let handle = Self {
            context: Arc::clone(ctx),
            stream: solver_stream,
            blas_handle: blas,
            ptx_cache: cache,
            sm_version: arch,
            workspace: scratch,
        };
        Ok(handle)
    }

    /// Guarantees that the workspace holds at least `bytes` elements.
    ///
    /// When the current workspace is already long enough this is a no-op, so
    /// the workspace never shrinks. Otherwise the workspace is replaced by a
    /// fresh `DeviceBuffer::zeroed(bytes)` allocation; nothing is copied from
    /// the previous buffer, so its contents are not carried over.
    ///
    /// # Errors
    ///
    /// Propagates the allocation error. In that case the existing workspace
    /// is left in place, because the assignment only happens after a
    /// successful allocation.
    pub fn ensure_workspace(&mut self, bytes: usize) -> SolverResult<()> {
        // Guard clause: the existing allocation already satisfies the request.
        if self.workspace.len() >= bytes {
            return Ok(());
        }

        // The replacement is allocated first; the old buffer is dropped by the
        // assignment only once the new one exists.
        let grown = DeviceBuffer::<u8>::zeroed(bytes)?;
        self.workspace = grown;
        Ok(())
    }

    /// Returns a shared reference to the BLAS handle owned by this solver.
    pub fn blas(&self) -> &BlasHandle {
        &self.blas_handle
    }

    /// Returns an exclusive reference to the BLAS handle owned by this solver.
    pub fn blas_mut(&mut self) -> &mut BlasHandle {
        &mut self.blas_handle
    }

    /// Returns the stream that was created for this solver in [`Self::new`].
    pub fn stream(&self) -> &Stream {
        &self.stream
    }

    /// Returns the shared context this solver was created from.
    pub fn context(&self) -> &Arc<Context> {
        &self.context
    }

    /// Returns the SM version captured from the BLAS handle at construction.
    pub fn sm_version(&self) -> SmVersion {
        self.sm_version
    }

    /// Returns a shared reference to the PTX cache.
    pub fn ptx_cache(&self) -> &PtxCache {
        &self.ptx_cache
    }

    /// Returns an exclusive reference to the PTX cache.
    pub fn ptx_cache_mut(&mut self) -> &mut PtxCache {
        &mut self.ptx_cache
    }

    /// Returns a shared reference to the device workspace buffer.
    pub fn workspace(&self) -> &DeviceBuffer<u8> {
        &self.workspace
    }

    /// Returns an exclusive reference to the device workspace buffer.
    ///
    /// Use [`Self::ensure_workspace`] first if a minimum size is required.
    pub fn workspace_mut(&mut self) -> &mut DeviceBuffer<u8> {
        &mut self.workspace
    }
}
#[cfg(test)]
mod tests {
    //! Host-only sanity checks.
    //!
    //! These tests do not construct a `SolverHandle` and do not touch any
    //! device; they only exercise plain integer comparisons that mirror the
    //! size check performed by `ensure_workspace`.

    /// Workspace size, in bytes, that the handle is expected to start with.
    const INITIAL_BYTES: usize = 4096;

    /// The expected initial workspace size is 4096 bytes.
    #[test]
    fn initial_workspace_size() {
        assert_eq!(INITIAL_BYTES, 4096);
    }

    /// A request larger than the current size satisfies the "grow" condition.
    #[test]
    fn workspace_requirement_logic() {
        let (have, want) = (4096_usize, 8192_usize);
        assert!(have < want, "should need reallocation");
    }

    /// A request no larger than the current size does not satisfy it.
    #[test]
    fn workspace_sufficient_logic() {
        let (have, want) = (8192_usize, 4096_usize);
        assert!(have >= want, "should not need reallocation");
    }
}