use std::sync::Arc;
use atomr_core::actor::{Context, Props};
use atomr_macros::Actor;
use cudarc::cublas::sys::cublasOperation_t;
use cudarc::cublas::{CudaBlas, Gemm, GemmConfig};
use crate::completion::{CompletionStrategy, HostFnCompletion};
use crate::device::{DeviceState, SgemmRequest};
use crate::error::GpuError;
use crate::kernel::envelope;
use crate::stream::{ActorHints, StreamAllocator};
/// Messages accepted by [`BlasActor`].
pub enum BlasMsg {
/// Run a single-precision GEMM. The boxed request carries the A/B/C buffers, the
/// `m`/`n`/`k` dimensions, the `alpha`/`beta` scalars and the `reply` channel on
/// which the outcome is reported.
Sgemm(Box<SgemmRequest>),
}
/// Actor that services [`BlasMsg`] requests. It is backed either by a cuBLAS handle or
/// by a mock that rejects every request.
// NOTE(review): `#[derive(Actor)]` and `#[msg(BlasMsg)]` come from `atomr_macros`. They
// presumably declare `BlasMsg` as the mailbox type and route it to `handle_msg` — confirm.
#[derive(Actor)]
#[msg(BlasMsg)]
pub struct BlasActor {
// Backend chosen by the constructor that built this actor (`props`, `props_legacy`
// or `mock_props`).
inner: BlasInner,
}
/// Backend state held by a [`BlasActor`].
enum BlasInner {
/// GPU-backed variant, built by `BlasActor::props`.
Real {
// cuBLAS handle created on `stream` when the actor is constructed.
blas: CudaBlas,
// Stream that SGEMM work is enqueued on.
stream: Arc<cudarc::driver::CudaStream>,
// Completion strategy handed to `envelope::run_kernel` for every enqueued kernel.
completion: Arc<dyn CompletionStrategy>,
// Never read in this file. Holding the `Arc` keeps the `DeviceState` alive for as
// long as the actor exists, which is why `dead_code` is allowed here.
#[allow(dead_code)]
state: Arc<DeviceState>,
},
/// Variant built by `BlasActor::mock_props`. It owns no GPU resources, and every
/// SGEMM request sent to it is answered with an error.
Mock,
}
/// Library name passed to `envelope::run_kernel` and attached to the
/// `GpuError::LibraryError` values produced by this module.
const LIB: &str = "cublas";
impl BlasActor {
    /// Builds [`Props`] for a cuBLAS-backed actor that issues its work on `stream`.
    ///
    /// A stream is acquired from `allocator` using default [`ActorHints`]. Debug builds
    /// assert that it is the very same `Arc` as `stream`. The actor itself always keeps
    /// and uses `stream`.
    ///
    /// # Panics
    ///
    /// The factory stored in the returned `Props` panics with a `ContextPoisoned`
    /// message if `CudaBlas::new` fails while creating the actor.
    pub fn props(
        stream: Arc<cudarc::driver::CudaStream>,
        allocator: Arc<dyn StreamAllocator>,
        completion: Arc<dyn CompletionStrategy>,
        state: Arc<DeviceState>,
    ) -> Props<Self> {
        let acquired = allocator.acquire(ActorHints::default());
        debug_assert!(Arc::ptr_eq(&acquired, &stream));

        Props::create(move || {
            // Each invocation of the factory gets its own cuBLAS handle bound to `stream`.
            let blas = CudaBlas::new(Arc::clone(&stream))
                .unwrap_or_else(|e| panic!("ContextPoisoned: CudaBlas::new failed: {e}"));
            Self {
                inner: BlasInner::Real {
                    blas,
                    stream: Arc::clone(&stream),
                    completion: Arc::clone(&completion),
                    state: Arc::clone(&state),
                },
            }
        })
    }

    /// Adapter over [`BlasActor::props`] for callers holding a concrete
    /// `PerActorAllocator` and [`HostFnCompletion`]. Both values are wrapped in `Arc`
    /// trait objects and forwarded unchanged.
    pub fn props_legacy(
        stream: Arc<cudarc::driver::CudaStream>,
        allocator: crate::stream::PerActorAllocator,
        completion: HostFnCompletion,
        state: Arc<DeviceState>,
    ) -> Props<Self> {
        // The `Arc<Concrete>` values unsize-coerce to the `Arc<dyn …>` parameters.
        Self::props(stream, Arc::new(allocator), Arc::new(completion), state)
    }

    /// Builds [`Props`] for an actor with no GPU backing. Every SGEMM request it
    /// receives is answered with an error instead of being executed.
    pub fn mock_props() -> Props<Self> {
        Props::create(|| Self {
            inner: BlasInner::Mock,
        })
    }
}
impl BlasActor {
    /// Dispatches one [`BlasMsg`] to the backend selected at construction time.
    ///
    /// For a mock actor, the request is answered immediately with
    /// `GpuError::Unrecoverable`. For a real actor, the request is unboxed and handed,
    /// reply channel included, to `enqueue_sgemm`. Reply send failures are ignored.
    async fn handle_msg(&mut self, _ctx: &mut Context<Self>, msg: BlasMsg) {
        match (msg, &self.inner) {
            (
                BlasMsg::Sgemm(req),
                BlasInner::Real {
                    blas,
                    stream,
                    completion,
                    ..
                },
            ) => {
                enqueue_sgemm(blas, stream, completion, *req);
            }
            (BlasMsg::Sgemm(req), BlasInner::Mock) => {
                let unsupported =
                    GpuError::Unrecoverable("Sgemm not supported in mock mode".into());
                let _ = req.reply.send(Err(unsupported));
            }
        }
    }
}
/// Validates an [`SgemmRequest`] and enqueues `C = alpha * A * B + beta * C` on `stream`
/// through cuBLAS.
///
/// All operands are passed untransposed (`CUBLAS_OP_N`). In cuBLAS's column-major
/// layout, A is `m x k`, B is `k x n` and C is `m x n`, each densely packed with
/// leading dimension equal to its row count.
///
/// This function returns nothing; every outcome goes through the request's `reply`.
/// The request is answered with an error, and nothing is enqueued, when:
/// - `envelope::access_all_3` rejects the buffers (its error is forwarded as-is);
/// - a dimension is negative, or a buffer holds fewer elements than its matrix needs
///   (`GpuError::Unrecoverable`);
/// - the C buffer has another live reference (`GpuError::Unrecoverable`).
///
/// Otherwise the gemm call is handed to `envelope::run_kernel` together with `reply`.
/// A failure from `CudaBlas::gemm` is mapped to `GpuError::LibraryError`.
fn enqueue_sgemm(
    blas: &CudaBlas,
    stream: &Arc<cudarc::driver::CudaStream>,
    completion: &Arc<dyn CompletionStrategy>,
    req: SgemmRequest,
) {
    let SgemmRequest {
        a,
        b,
        c,
        m,
        n,
        k,
        alpha,
        beta,
        reply,
    } = req;

    let (a_slice, b_slice, c_slice) = match envelope::access_all_3(&a, &b, &c) {
        Ok(t) => t,
        Err(e) => {
            let _ = reply.send(Err(e));
            return;
        }
    };

    // `CudaBlas::gemm` does not bounds-check its buffers, so an undersized buffer
    // would turn into an out-of-bounds device access. Reject such requests here.
    let a_len = cudarc::driver::DeviceSlice::len(&*a_slice);
    let b_len = cudarc::driver::DeviceSlice::len(&*b_slice);
    let c_len = cudarc::driver::DeviceSlice::len(&*c_slice);
    if let Err(msg) = check_sgemm_shapes(m, n, k, a_len, b_len, c_len) {
        let _ = reply.send(Err(GpuError::Unrecoverable(msg.into())));
        return;
    }

    let cfg = GemmConfig::<f32> {
        transa: cublasOperation_t::CUBLAS_OP_N,
        transb: cublasOperation_t::CUBLAS_OP_N,
        m,
        n,
        k,
        alpha,
        // cuBLAS requires ld >= max(1, rows) even when a matrix is empty. Clamping to 1
        // keeps m == 0 / k == 0 requests valid. For non-empty shapes it changes nothing.
        lda: m.max(1),
        ldb: k.max(1),
        beta,
        ldc: m.max(1),
    };

    // C is written by the kernel, so the caller must hold the only reference to it.
    let mut c_owned = match Arc::try_unwrap(c_slice) {
        Ok(s) => s,
        Err(_arc) => {
            let _ = reply.send(Err(GpuError::Unrecoverable(
                "SGEMM target buffer C has more than one live reference; \
                 caller must hold the unique GpuRef to write to it"
                    .into(),
            )));
            return;
        }
    };

    // NOTE(review): presumably records that `stream` writes C, so that later users on
    // other streams can synchronize against it — confirm against `record_write`.
    c.record_write(stream);

    envelope::run_kernel(LIB, stream, completion, (), reply, move || {
        // SAFETY: `check_sgemm_shapes` verified that A, B and C hold at least m*k, k*n
        // and m*n elements. With the packed leading dimensions in `cfg`, that covers
        // every element cuBLAS reads or writes. C is uniquely owned (`try_unwrap`
        // succeeded), so no other handle aliases the output buffer.
        let res = unsafe { blas.gemm(cfg, &*a_slice, &*b_slice, &mut c_owned) };
        match res {
            // The buffers are handed back, presumably so `run_kernel` keeps them alive
            // until the enqueued kernel has completed — NOTE(review): confirm.
            Ok(()) => Ok((a_slice, b_slice, c_owned)),
            Err(e) => Err(GpuError::LibraryError {
                lib: LIB,
                msg: format!("sgemm enqueue: {e}"),
            }),
        }
    });
}

/// Checks that densely packed `m x k` (A), `k x n` (B) and `m x n` (C) matrices fit in
/// buffers of `a_len`, `b_len` and `c_len` elements.
///
/// # Errors
///
/// Returns a description of the first violation found:
/// - a negative dimension;
/// - an element count that overflows `usize`;
/// - a buffer that is too small for its matrix.
fn check_sgemm_shapes(
    m: i32,
    n: i32,
    k: i32,
    a_len: usize,
    b_len: usize,
    c_len: usize,
) -> Result<(), String> {
    let operands = [("A", m, k, a_len), ("B", k, n, b_len), ("C", m, n, c_len)];
    for (name, rows, cols, have) in operands {
        let need = usize::try_from(rows)
            .ok()
            .zip(usize::try_from(cols).ok())
            .and_then(|(r, c)| r.checked_mul(c))
            .ok_or_else(|| format!("SGEMM matrix {name} has invalid shape {rows}x{cols}"))?;
        if need > have {
            return Err(format!(
                "SGEMM matrix {name} ({rows}x{cols}) needs {need} elements \
                 but its buffer holds {have}"
            ));
        }
    }
    Ok(())
}