use std::sync::Arc;
use oxicuda_driver::{Context, Stream};
use oxicuda_ptx::arch::SmVersion;
use crate::error::{BlasError, BlasResult};
use crate::types::{MathMode, PointerMode};
/// State shared by BLAS operations: the driver context and stream the handle
/// was built with, the configured math and pointer modes, and the SM
/// architecture detected for the context's device.
pub struct BlasHandle {
    /// Stream associated with this handle.
    ///
    /// Declared *before* `context` on purpose: struct fields are dropped in
    /// declaration order, so the stream is released while this handle's
    /// reference to the context is still alive. Do not reorder.
    stream: Stream,
    /// Context the handle was created from, shared with the caller via `Arc`.
    context: Arc<Context>,
    /// Math mode; set to `MathMode::Default` when the handle is constructed.
    math_mode: MathMode,
    /// Pointer mode; set to `PointerMode::Host` when the handle is constructed.
    pointer_mode: PointerMode,
    /// SM version derived from the device's compute capability at construction.
    sm_version: SmVersion,
}
impl BlasHandle {
    /// Creates a handle for `ctx` together with a newly created stream.
    ///
    /// # Errors
    ///
    /// Propagates any error from `Stream::new`. Also returns every error
    /// documented on [`BlasHandle::with_stream`].
    pub fn new(ctx: &Arc<Context>) -> BlasResult<Self> {
        Self::with_stream(ctx, Stream::new(ctx)?)
    }

    /// Creates a handle for `ctx` that uses the caller-supplied `stream`.
    ///
    /// This constructor does not check that `stream` was created on `ctx`.
    ///
    /// # Errors
    ///
    /// Propagates any error from querying the device's compute capability.
    /// Returns [`BlasError::UnsupportedOperation`] when that capability does
    /// not map to a known [`SmVersion`].
    pub fn with_stream(ctx: &Arc<Context>, stream: Stream) -> BlasResult<Self> {
        Self::build(ctx, stream)
    }

    /// Shared constructor: detects the SM version and fills in default modes
    /// (`MathMode::Default`, `PointerMode::Host`).
    fn build(ctx: &Arc<Context>, stream: Stream) -> BlasResult<Self> {
        let dev = ctx.device();
        let (major, minor) = dev.compute_capability()?;
        let sm_version = match SmVersion::from_compute_capability(major, minor) {
            Some(sm) => sm,
            None => {
                return Err(BlasError::UnsupportedOperation(format!(
                    "unsupported compute capability: {major}.{minor}"
                )));
            }
        };
        Ok(Self {
            context: Arc::clone(ctx),
            stream,
            math_mode: MathMode::Default,
            pointer_mode: PointerMode::Host,
            sm_version,
        })
    }

    /// Returns the context this handle was created from.
    pub fn context(&self) -> &Arc<Context> {
        &self.context
    }

    /// Returns the stream currently associated with this handle.
    pub fn stream(&self) -> &Stream {
        &self.stream
    }

    /// Returns the SM version detected when the handle was constructed.
    pub fn sm_version(&self) -> SmVersion {
        self.sm_version
    }

    /// Returns the current math mode.
    pub fn math_mode(&self) -> MathMode {
        self.math_mode
    }

    /// Returns the current pointer mode.
    pub fn pointer_mode(&self) -> PointerMode {
        self.pointer_mode
    }

    /// Replaces the associated stream; the previous stream is dropped.
    pub fn set_stream(&mut self, stream: Stream) {
        self.stream = stream;
    }

    /// Sets the math mode.
    pub fn set_math_mode(&mut self, mode: MathMode) {
        self.math_mode = mode;
    }

    /// Sets the pointer mode.
    pub fn set_pointer_mode(&mut self, mode: PointerMode) {
        self.pointer_mode = mode;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that the mode variants used as construction defaults exist and
    /// compare equal to themselves.
    ///
    /// NOTE(review): this does not exercise `BlasHandle` itself, because
    /// building one needs a live driver context. It only guards against the
    /// variants being renamed or removed.
    #[test]
    fn default_modes() {
        let math = MathMode::Default;
        let pointer = PointerMode::Host;
        assert_eq!(math, MathMode::Default);
        assert_eq!(pointer, PointerMode::Host);
    }
}