use std::sync::Arc;
use async_trait::async_trait;
use atomr_core::actor::{Actor, Context, Props};
use cudarc::cutensor::result as ct_result;
use cudarc::cutensor::sys as ct_sys;
use parking_lot::Mutex;
use tokio::sync::oneshot;
use crate::completion::CompletionStrategy;
use crate::device::DeviceState;
use crate::error::GpuError;
use crate::kernel::dispatch::{TensorDispatch, TensorDispatchCtx, WorkspacePool};
use crate::stream::StreamAllocator;
// Submodules of this module. `autotune` is only compiled when the
// `cutensor-autotune` Cargo feature is enabled.
#[cfg(feature = "cutensor-autotune")]
pub mod autotune;
pub mod compute_desc;
pub mod contract;
pub mod elementwise;
pub mod permute;
pub mod plan_cache;
pub mod reduce;
// Re-exports so that callers can name the request types and plan-cache items
// directly from this module instead of reaching into each submodule.
pub use compute_desc::ComputeDesc;
pub use contract::{ContractRequest, OperandSpec};
pub use elementwise::{ElementwiseBinaryRequest, ElementwiseTrinaryRequest};
pub use permute::PermutationRequest;
pub use plan_cache::{PlanCache, PlanKey, DEFAULT_PLAN_CACHE_SIZE};
pub use reduce::ReductionRequest;
/// Shorthand for an [`OperandSpec`] over `f32`.
///
/// This is the operand type carried by the deprecated `TensorMsg::Contract`
/// variant.
pub type TensorSpec = OperandSpec<f32>;
/// Newtype around the raw cuTENSOR handle that opts it into `Send` and `Sync`.
///
/// The wrapped value is exposed as the public tuple field `.0`.
pub struct SendHandle(pub ct_sys::cutensorHandle_t);
// SAFETY: NOTE(review) — these impls assert that the raw handle may be moved
// to, and referenced from, other threads. Nothing in this file proves that;
// the only storage site visible here keeps the handle behind a `Mutex`
// (see `TensorActor::props`). Confirm against the cuTENSOR thread-safety
// documentation before relying on unsynchronised shared access.
unsafe impl Send for SendHandle {}
unsafe impl Sync for SendHandle {}
/// Messages accepted by [`TensorActor`].
pub enum TensorMsg {
/// A type-erased operation. The actor calls `dispatch` on it when running
/// against a real context, or `fail_mock` when running in mock mode.
Op(Box<dyn TensorDispatch>),
/// Legacy `f32` contraction message, kept for backward compatibility.
///
/// The actor repackages the fields into a `ContractRequest::<f32>` and
/// dispatches that, so this is equivalent to sending the corresponding
/// [`TensorMsg::Op`].
#[deprecated(note = "use TensorMsg::Op(Box::new(ContractRequest::<f32>::new(...)))")]
Contract {
/// First operand spec, forwarded as the first argument of `ContractRequest::new`.
a: TensorSpec,
/// Second operand spec.
b: TensorSpec,
/// Third operand spec (presumably the output tensor — confirm in `contract`).
c: TensorSpec,
/// Scalar forwarded to `ContractRequest::new` (presumably the scale on the
/// product term — confirm in `contract`).
alpha: f32,
/// Scalar forwarded to `ContractRequest::new` (presumably the scale on the
/// existing output — confirm in `contract`).
beta: f32,
/// One-shot channel on which the outcome is reported. In mock mode an
/// `Err(GpuError::Unrecoverable(..))` is sent on it.
reply: oneshot::Sender<Result<(), GpuError>>,
},
}
/// Actor that handles [`TensorMsg`] messages.
///
/// Build it through [`TensorActor::props`] (real cuTENSOR handle) or
/// [`TensorActor::mock_props`] (no handle; every request is failed).
pub struct TensorActor {
// Selects between the real and the mock behaviour; see `TensorInner`.
inner: TensorInner,
}
/// Backing state of a [`TensorActor`].
// The size-difference lint is silenced rather than boxing the `Real` payload
// (presumably because `Real` is the common case — confirm with the author).
#[allow(clippy::large_enum_variant)]
enum TensorInner {
/// Backed by a live cuTENSOR handle.
Real {
/// Handle, stream, completion strategy, plan cache and workspace pool
/// handed to every dispatched operation.
ctx: TensorDispatchCtx,
/// Device state passed in by the caller of `props`. It is stored but
/// never read in this file, hence the `dead_code` allowance.
#[allow(dead_code)]
state: Arc<DeviceState>,
},
/// No handle is created; every message is answered with a failure
/// (see `mock_reply`).
Mock,
}
impl Drop for TensorInner {
    /// Destroys the cuTENSOR handle held by the `Real` variant.
    ///
    /// The `Mock` variant holds no handle, so dropping it does nothing here.
    fn drop(&mut self) {
        match self {
            TensorInner::Mock => {}
            TensorInner::Real { ctx, .. } => {
                // Keep the lock held for the duration of the destroy call.
                let guard = ctx.handle.lock();
                // SAFETY: NOTE(review) — assumes the handle was produced by
                // `create_handle` (as `TensorActor::props` does) and is
                // destroyed only here. `ctx.handle` is an `Arc`, so any clone
                // still held by dispatched work would be left with a destroyed
                // handle; confirm that no such clone outlives the actor.
                //
                // The result is deliberately discarded: `drop` has no way to
                // report a failure to anyone.
                let _ = unsafe { ct_result::destroy_handle(guard.0) };
            }
        }
    }
}
impl TensorActor {
    /// Returns `Props` whose factory builds an actor backed by a real
    /// cuTENSOR handle.
    ///
    /// Each time the factory closure runs it creates a new handle, a plan
    /// cache with the default capacity, and a workspace pool built from
    /// `stream`, and bundles them with `stream` and `completion` into the
    /// dispatch context.
    ///
    /// * `stream` — CUDA stream stored in the dispatch context and handed to
    ///   the workspace pool.
    /// * `_allocator` — accepted for interface compatibility; not used here.
    /// * `completion` — completion strategy stored in the dispatch context.
    /// * `state` — device state stored alongside the context.
    ///
    /// # Panics
    ///
    /// The factory closure panics if `create_handle` returns an error.
    pub fn props(
        stream: Arc<cudarc::driver::CudaStream>,
        _allocator: Arc<dyn StreamAllocator>,
        completion: Arc<dyn CompletionStrategy>,
        state: Arc<DeviceState>,
    ) -> Props<Self> {
        Props::create(move || {
            let raw_handle = ct_result::create_handle()
                .unwrap_or_else(|e| panic!("ContextPoisoned: cutensorCreate failed: {e}"));

            // The captured `Arc`s are cloned rather than moved so that the
            // closure can be invoked more than once.
            let handle = Arc::new(Mutex::new(SendHandle(raw_handle)));
            let plan_cache = Arc::new(PlanCache::with_default_capacity());
            let workspace = Arc::new(WorkspacePool::new(Arc::clone(&stream)));

            let inner = TensorInner::Real {
                ctx: TensorDispatchCtx {
                    handle,
                    stream: Arc::clone(&stream),
                    completion: Arc::clone(&completion),
                    plan_cache,
                    workspace,
                },
                state: Arc::clone(&state),
            };
            TensorActor { inner }
        })
    }

    /// Returns `Props` whose factory builds an actor in mock mode.
    ///
    /// No cuTENSOR handle is created; the resulting actor fails every
    /// message it receives.
    pub fn mock_props() -> Props<Self> {
        Props::create(|| {
            let inner = TensorInner::Mock;
            TensorActor { inner }
        })
    }
}
#[async_trait]
impl Actor for TensorActor {
type Msg = TensorMsg;
async fn handle(&mut self, _ctx: &mut Context<Self>, msg: TensorMsg) {
match &self.inner {
TensorInner::Mock => mock_reply(msg),
TensorInner::Real { ctx, .. } => match msg {
TensorMsg::Op(req) => req.dispatch(ctx),
#[allow(deprecated)]
TensorMsg::Contract {
a,
b,
c,
alpha,
beta,
reply,
} => {
let req = ContractRequest::<f32>::new(a, b, c, alpha, beta, reply);
Box::new(req).dispatch(ctx);
}
},
}
}
}
/// Fails `msg` on behalf of an actor running in mock mode.
///
/// An `Op` is asked to fail itself via `fail_mock`. A legacy `Contract`
/// has an `Err(GpuError::Unrecoverable(..))` sent on its reply channel; the
/// result of that send is discarded.
fn mock_reply(msg: TensorMsg) {
    let reply = match msg {
        TensorMsg::Op(req) => {
            req.fail_mock();
            return;
        }
        #[allow(deprecated)]
        TensorMsg::Contract { reply, .. } => reply,
    };

    let failure = GpuError::Unrecoverable("TensorActor in mock mode".into());
    let _ = reply.send(Err(failure));
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Reply channel type shared by the stub request and the legacy variant.
    type ReplyTx = oneshot::Sender<Result<(), GpuError>>;

    /// Minimal [`TensorDispatch`] stub that only knows how to fail in mock mode.
    struct MockReq {
        reply: Option<ReplyTx>,
    }

    impl TensorDispatch for MockReq {
        fn op_tag(&self) -> &'static str {
            "mock"
        }

        fn dtype_tag(&self) -> &'static str {
            "mock"
        }

        // Real dispatch is never exercised by these tests.
        fn dispatch(self: Box<Self>, _ctx: &TensorDispatchCtx) {}

        fn fail_mock(self: Box<Self>) {
            let MockReq { reply } = *self;
            if let Some(tx) = reply {
                let _ = tx.send(Err(GpuError::Unrecoverable(
                    "TensorActor in mock mode".into(),
                )));
            }
        }
    }

    /// Sends an `Op` through `mock_reply` and checks that an unrecoverable
    /// error comes back, then runs the compile-time check that the deprecated
    /// `Contract` variant can still be constructed.
    #[test]
    fn deprecated_contract_alias_still_constructs() {
        let (reply_tx, reply_rx) = oneshot::channel();
        let stub = MockReq {
            reply: Some(reply_tx),
        };

        mock_reply(TensorMsg::Op(Box::new(stub)));

        // The test is synchronous, so the blocking receive is used.
        let outcome = reply_rx
            .blocking_recv()
            .expect("Op mock_reply must send a result");
        assert!(matches!(outcome, Err(GpuError::Unrecoverable(_))));

        legacy_contract_mock_path();
    }

    /// Type-checks construction of the deprecated `TensorMsg::Contract`
    /// variant. The constructor is only coerced to a function pointer; it is
    /// never called, so no operand specs have to be built.
    #[allow(deprecated)]
    #[allow(dead_code)]
    fn legacy_contract_mock_path() {
        type Build = fn(TensorSpec, TensorSpec, TensorSpec, f32, f32, ReplyTx) -> TensorMsg;

        let _build: Build = |a, b, c, alpha, beta, reply| TensorMsg::Contract {
            a,
            b,
            c,
            alpha,
            beta,
            reply,
        };
    }
}