use std::sync::Arc;
use atomr_core::actor::{Context, Props};
use atomr_macros::Actor;
use cudarc::cublas::CudaBlas;
use crate::completion::{CompletionStrategy, HostFnCompletion};
use crate::device::{DeviceState, SgemmRequest};
use crate::error::GpuError;
use crate::kernel::dispatch::{
BlasDispatchCtx, BlasL1Dispatch, BlasL2Dispatch, BlasL3Dispatch, GemmDispatch,
GemmStridedBatchedDispatch,
};
use crate::stream::{ActorHints, StreamAllocator};
pub mod gemm;
pub mod gemm_strided_batched;
pub mod l1;
pub mod l2;
pub mod l3;
pub mod scaling;
pub use gemm::GemmRequest;
pub use gemm_strided_batched::GemmStridedBatchedRequest;
pub use l1::{
AsumRequest, AxpyRequest, CopyRequest, DotRequest, IamaxRequest, IaminRequest, Nrm2Request,
RotRequest, ScalRequest, SwapRequest,
};
pub use l2::{GemvRequest, GerRequest};
pub use l3::{GeamRequest, SyrkRequest, TrsmRequest};
/// Messages accepted by [`BlasActor`].
///
/// Every non-deprecated variant carries a type-erased request. A real actor calls the
/// request's `dispatch` with its cuBLAS context. A mock actor logs the request's
/// `op_name` and drops it without replying.
pub enum BlasMsg {
    /// A general matrix multiply; build one with [`BlasMsg::gemm`].
    Gemm(Box<dyn GemmDispatch>),
    /// A BLAS level-1 operation (see the `l1` re-exports: axpy, dot, nrm2, scal, ...).
    L1(Box<dyn BlasL1Dispatch>),
    /// A BLAS level-2 operation (see the `l2` re-exports: gemv, ger).
    L2(Box<dyn BlasL2Dispatch>),
    /// A BLAS level-3 / extension operation (see the `l3` re-exports: geam, syrk, trsm).
    L3(Box<dyn BlasL3Dispatch>),
    /// A strided-batched GEMM; build one with [`BlasMsg::gemm_strided_batched`].
    GemmStridedBatched(Box<dyn GemmStridedBatchedDispatch>),
    /// Legacy single-precision GEMM.
    ///
    /// A real actor converts it to a `GemmRequest<f32>` with no transposes and leading
    /// dimensions `lda = m`, `ldb = k`, `ldc = m`. A mock actor answers it with
    /// `GpuError::Unrecoverable`.
    #[deprecated(note = "use BlasMsg::gemm::<f32>(GemmRequest::<f32> { ... })")]
    Sgemm(Box<crate::device::SgemmRequest>),
}
impl BlasMsg {
pub fn gemm<T: crate::dtype::GemmSupported>(req: GemmRequest<T>) -> Self
where
GemmRequest<T>: GemmDispatch,
{
Self::Gemm(Box::new(req))
}
pub fn gemm_strided_batched<T: crate::dtype::GemmSupported>(
req: GemmStridedBatchedRequest<T>,
) -> Self
where
GemmStridedBatchedRequest<T>: GemmStridedBatchedDispatch,
{
Self::GemmStridedBatched(Box::new(req))
}
}
/// Actor that executes [`BlasMsg`] requests.
///
/// When built with [`BlasActor::props`], it runs each request through a cuBLAS handle
/// bound to one CUDA stream. When built with [`BlasActor::mock_props`], it touches no
/// GPU and discards requests.
#[derive(Actor)]
#[msg(BlasMsg)]
pub struct BlasActor {
    // Live cuBLAS resources, or the mock marker.
    inner: BlasInner,
}
/// Backing state of a [`BlasActor`]: either real GPU resources or a mock.
pub(crate) enum BlasInner {
    /// Resources created by `BlasActor::props`. Each field is lent (by reference)
    /// to every request through a `BlasDispatchCtx`.
    Real {
        /// cuBLAS handle, created with `CudaBlas::new` on `stream`.
        cublas: Arc<CudaBlas>,
        /// CUDA stream the cuBLAS handle was created on.
        stream: Arc<cudarc::driver::CudaStream>,
        /// Completion strategy handed to dispatches. How completion is signalled is
        /// up to the implementation (e.g. `HostFnCompletion`); not visible here.
        completion: Arc<dyn CompletionStrategy>,
        /// Shared device state handed to dispatches.
        /// NOTE(review): its exact role lives in `crate::device`; confirm and document.
        state: Arc<DeviceState>,
    },
    /// No GPU resources. Requests are logged and dropped, and legacy `Sgemm` gets an
    /// error reply.
    Mock,
}
impl BlasActor {
    /// Builds [`Props`] for a cuBLAS-backed `BlasActor` that works on `stream`.
    ///
    /// `allocator` is asked for a stream exactly once, with default [`ActorHints`]. In
    /// debug builds the returned stream must be the same `Arc` as `stream` (checked
    /// with [`Arc::ptr_eq`]). Otherwise the acquired stream is not used, and the actor
    /// always works on `stream`.
    ///
    /// Each time the factory closure runs, it creates a fresh cuBLAS handle on `stream`
    /// and shares `stream`, `completion` and `state` with the new actor.
    ///
    /// # Panics
    ///
    /// The factory closure panics with a `ContextPoisoned:` message if
    /// `CudaBlas::new` fails.
    pub fn props(
        stream: Arc<cudarc::driver::CudaStream>,
        allocator: Arc<dyn StreamAllocator>,
        completion: Arc<dyn CompletionStrategy>,
        state: Arc<DeviceState>,
    ) -> Props<Self> {
        let acquired = allocator.acquire(ActorHints::default());
        debug_assert!(Arc::ptr_eq(&acquired, &stream));

        Props::create(move || {
            let handle = CudaBlas::new(Arc::clone(&stream))
                .unwrap_or_else(|e| panic!("ContextPoisoned: CudaBlas::new failed: {e}"));
            Self {
                inner: BlasInner::Real {
                    cublas: Arc::new(handle),
                    stream: Arc::clone(&stream),
                    completion: Arc::clone(&completion),
                    state: Arc::clone(&state),
                },
            }
        })
    }

    /// Compatibility constructor that accepts the concrete `PerActorAllocator` and
    /// `HostFnCompletion` types. It erases them to trait objects and delegates to
    /// [`BlasActor::props`].
    pub fn props_legacy(
        stream: Arc<cudarc::driver::CudaStream>,
        allocator: crate::stream::PerActorAllocator,
        completion: HostFnCompletion,
        state: Arc<DeviceState>,
    ) -> Props<Self> {
        Self::props(
            stream,
            Arc::new(allocator) as Arc<dyn StreamAllocator>,
            Arc::new(completion) as Arc<dyn CompletionStrategy>,
            state,
        )
    }

    /// Builds [`Props`] for a mock `BlasActor` that holds no GPU resources.
    pub fn mock_props() -> Props<Self> {
        Props::create(|| Self {
            inner: BlasInner::Mock,
        })
    }
}
impl BlasActor {
async fn handle_msg(&mut self, _ctx: &mut Context<Self>, msg: BlasMsg) {
match &self.inner {
BlasInner::Mock => match msg {
BlasMsg::Gemm(d) => mock_reply(d.op_name()),
BlasMsg::L1(d) => mock_reply(d.op_name()),
BlasMsg::L2(d) => mock_reply(d.op_name()),
BlasMsg::L3(d) => mock_reply(d.op_name()),
BlasMsg::GemmStridedBatched(d) => mock_reply(d.op_name()),
#[allow(deprecated)]
BlasMsg::Sgemm(req) => {
let _ = req.reply.send(Err(GpuError::Unrecoverable(
"Sgemm not supported in mock mode".into(),
)));
}
},
BlasInner::Real {
cublas,
stream,
completion,
state,
} => {
let ctx = BlasDispatchCtx {
cublas,
stream,
completion,
state,
};
match msg {
BlasMsg::Gemm(d) => d.dispatch(&ctx),
BlasMsg::L1(d) => d.dispatch(&ctx),
BlasMsg::L2(d) => d.dispatch(&ctx),
BlasMsg::L3(d) => d.dispatch(&ctx),
BlasMsg::GemmStridedBatched(d) => d.dispatch(&ctx),
#[allow(deprecated)]
BlasMsg::Sgemm(req) => {
let SgemmRequest {
a,
b,
c,
m,
n,
k,
alpha,
beta,
reply,
} = *req;
let typed = GemmRequest::<f32> {
a,
b,
c,
m,
n,
k,
alpha,
beta,
trans_a: cudarc::cublas::sys::cublasOperation_t::CUBLAS_OP_N,
trans_b: cudarc::cublas::sys::cublasOperation_t::CUBLAS_OP_N,
lda: m,
ldb: k,
ldc: m,
reply,
};
let boxed: Box<dyn GemmDispatch> = Box::new(typed);
boxed.dispatch(&ctx);
}
}
}
}
}
}
/// Logs at debug level that a mock-mode [`BlasActor`] is discarding the `op` request.
///
/// Despite its name, this sends no reply. The request is dropped by the caller
/// afterwards, so the requester presumably observes a closed reply channel
/// (NOTE(review): confirm against the request types' `reply` field).
fn mock_reply(op: &'static str) {
    tracing::debug!(op = op, "BlasActor (mock): dropping op without reply");
}