use std::sync::Arc;
use oxicuda_blas::BlasHandle;
use oxicuda_driver::{Context, Stream};
use oxicuda_memory::DeviceBuffer;
use oxicuda_ptx::arch::SmVersion;
use oxicuda_ptx::cache::PtxCache;
use crate::error::{DnnError, DnnResult};
/// Handle bundling the state needed to issue DNN operations: a shared
/// context, a stream, a BLAS handle, a PTX cache, the detected SM version and
/// an optional workspace buffer.
///
/// Instances are created with `DnnHandle::new` or `DnnHandle::with_stream`.
///
/// NOTE: Rust drops struct fields in declaration order, so reordering the
/// fields below changes the order in which these resources are released.
pub struct DnnHandle {
/// Shared context the handle was built against (cloned `Arc`).
context: Arc<Context>,
/// Stream owned by this handle; supplied by the caller or created in
/// `new`, and replaceable through `set_stream`.
stream: Stream,
/// BLAS handle, constructed on its own separately created stream.
blas_handle: BlasHandle,
/// PTX cache created when the handle is built.
ptx_cache: PtxCache,
/// SM version derived from the device's compute capability at build time.
sm_version: SmVersion,
/// Optional byte workspace; `None` until `set_workspace` succeeds and
/// again after `clear_workspace`.
workspace: Option<DeviceBuffer<u8>>,
}
impl DnnHandle {
pub fn new(ctx: &Arc<Context>) -> DnnResult<Self> {
let stream = Stream::new(ctx)?;
Self::build(ctx, stream)
}
pub fn with_stream(ctx: &Arc<Context>, stream: Stream) -> DnnResult<Self> {
Self::build(ctx, stream)
}
fn build(ctx: &Arc<Context>, stream: Stream) -> DnnResult<Self> {
let device = ctx.device();
let (major, minor) = device.compute_capability()?;
let sm_version = SmVersion::from_compute_capability(major, minor).ok_or_else(|| {
DnnError::UnsupportedOperation(format!(
"unsupported compute capability: {major}.{minor}"
))
})?;
let blas_stream = Stream::new(ctx)?;
let blas_handle = BlasHandle::with_stream(ctx, blas_stream)?;
let ptx_cache = PtxCache::new()?;
Ok(Self {
context: Arc::clone(ctx),
stream,
blas_handle,
ptx_cache,
sm_version,
workspace: None,
})
}
#[inline]
pub fn context(&self) -> &Arc<Context> {
&self.context
}
#[inline]
pub fn stream(&self) -> &Stream {
&self.stream
}
#[inline]
pub fn blas(&self) -> &BlasHandle {
&self.blas_handle
}
#[inline]
pub fn blas_mut(&mut self) -> &mut BlasHandle {
&mut self.blas_handle
}
#[inline]
pub fn sm_version(&self) -> SmVersion {
self.sm_version
}
#[inline]
pub fn ptx_cache(&self) -> &PtxCache {
&self.ptx_cache
}
#[inline]
pub fn workspace(&self) -> Option<&DeviceBuffer<u8>> {
self.workspace.as_ref()
}
#[inline]
pub fn workspace_mut(&mut self) -> Option<&mut DeviceBuffer<u8>> {
self.workspace.as_mut()
}
pub fn set_stream(&mut self, stream: Stream) {
self.stream = stream;
}
pub fn set_workspace(&mut self, bytes: usize) -> DnnResult<()> {
if bytes == 0 {
return Err(DnnError::InvalidArgument(
"workspace size must be non-zero".into(),
));
}
if let Some(ref ws) = self.workspace {
if ws.len() >= bytes {
return Ok(());
}
}
let buf = DeviceBuffer::<u8>::alloc(bytes)?;
self.workspace = Some(buf);
Ok(())
}
pub fn clear_workspace(&mut self) {
self.workspace = None;
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that an `SmVersion` value can still be used after being
    /// assigned to another binding, and that both compare equal.
    #[test]
    fn sm_version_is_copy() {
        let original = SmVersion::Sm80;
        let duplicate = original;
        // Using `original` after the assignment only compiles if the
        // assignment copied the value instead of moving it.
        assert_eq!(original, duplicate);
    }
}