atomr-accel-cuda 0.10.0

GPU acceleration via the actor model. Wraps NVIDIA CUDA libraries (cuBLAS, cuDNN, cuFFT, cuRAND, cuSOLVER, cuSPARSE, cuTENSOR, cuBLASLt, NVRTC, NCCL) as supervised atomr actors with generation-validated buffers and a uniform async surface.
Documentation
//! `TensorActor` — wraps cuTENSOR for contractions, reductions,
//! permutations, and binary/trinary elementwise ops.
//!
//! cuTENSOR's safe `cudarc::cutensor::result` layer covers
//! contractions and reductions only. The remaining ops drop down to
//! `cudarc::cutensor::sys` via the local
//! [`crate::sys::cutensor`] wrappers.
//!
//! The actor is a single mailbox carrying `TensorMsg::Op(Box<dyn
//! TensorDispatch>)`. Each typed request — `ContractRequest<T>`,
//! `ReductionRequest<T>`, `ElementwiseBinaryRequest<T>`,
//! `ElementwiseTrinaryRequest<T>`, `PermutationRequest<T>` — implements
//! [`TensorDispatch`](crate::kernel::dispatch::TensorDispatch) so it
//! erases through a single mailbox without mailbox-per-dtype blowup.

use std::sync::Arc;

use async_trait::async_trait;
use atomr_core::actor::{Actor, Context, Props};
use cudarc::cutensor::result as ct_result;
use cudarc::cutensor::sys as ct_sys;
use parking_lot::Mutex;
use tokio::sync::oneshot;

use crate::completion::CompletionStrategy;
use crate::device::DeviceState;
use crate::error::GpuError;
use crate::kernel::dispatch::{TensorDispatch, TensorDispatchCtx, WorkspacePool};
use crate::stream::StreamAllocator;

#[cfg(feature = "cutensor-autotune")]
pub mod autotune;
pub mod compute_desc;
pub mod contract;
pub mod elementwise;
pub mod permute;
pub mod plan_cache;
pub mod reduce;

pub use compute_desc::ComputeDesc;
pub use contract::{ContractRequest, OperandSpec};
pub use elementwise::{ElementwiseBinaryRequest, ElementwiseTrinaryRequest};
pub use permute::PermutationRequest;
pub use plan_cache::{PlanCache, PlanKey, DEFAULT_PLAN_CACHE_SIZE};
pub use reduce::ReductionRequest;

/// Pre-Phase-2 spec: alias of `OperandSpec<f32>` so existing call
/// sites (`tests/contract_e2e.rs`, downstream users) keep compiling.
/// New code should prefer `OperandSpec<T>` directly.
pub type TensorSpec = OperandSpec<f32>;

/// Newtype around `cutensorHandle_t` so it can be stored in `Arc<Mutex>`.
/// Manually `Send`/`Sync` because the wrapped pointer is opaque and
/// cuTENSOR's docs guarantee thread-safety of the handle when callers
/// serialise access (the `Mutex` does that).
pub struct SendHandle(pub ct_sys::cutensorHandle_t);
unsafe impl Send for SendHandle {}
unsafe impl Sync for SendHandle {}

/// `TensorActor` mailbox. New requests should use [`TensorMsg::Op`]
/// with a `Box<dyn TensorDispatch>` — the `Contract` variant is kept
/// as a deprecated thin alias for back-compat with the F-phase API.
pub enum TensorMsg {
    /// Type-erased dtype-generic request.
    Op(Box<dyn TensorDispatch>),
    /// Legacy f32-only contraction. Constructs a
    /// `ContractRequest<f32>` internally and routes it through the
    /// same dispatch path. Deprecated.
    #[deprecated(note = "use TensorMsg::Op(Box::new(ContractRequest::<f32>::new(...)))")]
    Contract {
        a: TensorSpec,
        b: TensorSpec,
        c: TensorSpec,
        alpha: f32,
        beta: f32,
        reply: oneshot::Sender<Result<(), GpuError>>,
    },
}

pub struct TensorActor {
    inner: TensorInner,
}

#[allow(clippy::large_enum_variant)]
enum TensorInner {
    Real {
        ctx: TensorDispatchCtx,
        #[allow(dead_code)]
        state: Arc<DeviceState>,
    },
    Mock,
}

impl Drop for TensorInner {
    fn drop(&mut self) {
        if let TensorInner::Real { ctx, .. } = self {
            let h = ctx.handle.lock();
            unsafe {
                let _ = ct_result::destroy_handle(h.0);
            }
        }
    }
}

impl TensorActor {
    pub fn props(
        stream: Arc<cudarc::driver::CudaStream>,
        _allocator: Arc<dyn StreamAllocator>,
        completion: Arc<dyn CompletionStrategy>,
        state: Arc<DeviceState>,
    ) -> Props<Self> {
        Props::create(move || {
            let h = match ct_result::create_handle() {
                Ok(h) => h,
                Err(e) => panic!("ContextPoisoned: cutensorCreate failed: {e}"),
            };
            let ctx = TensorDispatchCtx {
                handle: Arc::new(Mutex::new(SendHandle(h))),
                stream: stream.clone(),
                completion: completion.clone(),
                plan_cache: Arc::new(PlanCache::with_default_capacity()),
                workspace: Arc::new(WorkspacePool::new(stream.clone())),
            };
            TensorActor {
                inner: TensorInner::Real {
                    ctx,
                    state: state.clone(),
                },
            }
        })
    }

    pub fn mock_props() -> Props<Self> {
        Props::create(|| TensorActor {
            inner: TensorInner::Mock,
        })
    }
}

#[async_trait]
impl Actor for TensorActor {
    type Msg = TensorMsg;

    async fn handle(&mut self, _ctx: &mut Context<Self>, msg: TensorMsg) {
        match &self.inner {
            TensorInner::Mock => mock_reply(msg),
            TensorInner::Real { ctx, .. } => match msg {
                TensorMsg::Op(req) => req.dispatch(ctx),
                #[allow(deprecated)]
                TensorMsg::Contract {
                    a,
                    b,
                    c,
                    alpha,
                    beta,
                    reply,
                } => {
                    let req = ContractRequest::<f32>::new(a, b, c, alpha, beta, reply);
                    Box::new(req).dispatch(ctx);
                }
            },
        }
    }
}

fn mock_reply(msg: TensorMsg) {
    match msg {
        TensorMsg::Op(req) => req.fail_mock(),
        #[allow(deprecated)]
        TensorMsg::Contract { reply, .. } => {
            let _ = reply.send(Err(GpuError::Unrecoverable(
                "TensorActor in mock mode".into(),
            )));
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// The deprecated `Contract` variant must still construct via
    /// pattern matching — proves the back-compat alias compiles after
    /// the refactor. We can't dispatch it without a GPU, but we can
    /// confirm the variant arms exist and the mock-mode reply path
    /// fires through both shapes (`Op(Box<dyn TensorDispatch>)` and
    /// the legacy `Contract { ... }`).
    #[test]
    fn deprecated_contract_alias_still_constructs() {
        // Op(Box<dyn TensorDispatch>) path.
        let (tx_op, rx_op) = oneshot::channel();
        let mock = MockReq { reply: Some(tx_op) };
        let msg_op = TensorMsg::Op(Box::new(mock));
        mock_reply(msg_op);
        let res = rx_op
            .blocking_recv()
            .expect("Op mock_reply must send a result");
        assert!(matches!(res, Err(GpuError::Unrecoverable(_))));

        // Legacy Contract path: `mock_reply` must still match the
        // variant and fire its `reply`. We can't materialise a real
        // GpuRef, so build the variant via a destructuring trick that
        // bypasses TensorSpec construction — instead we use the
        // higher-level discriminant check: ensure the variant arm in
        // `mock_reply` is reachable by pattern matching against a
        // shape we manually construct via a dedicated factory.
        legacy_contract_mock_path();
    }

    /// Materialises a `TensorMsg::Contract { ... }` value in a way
    /// that requires only the public surface — we cannot build a
    /// real `GpuRef<f32>` host-side, so we leave it to the GPU
    /// integration test. This function compiles only if the variant
    /// is still present, which is the back-compat guarantee under
    /// test.
    #[allow(deprecated)]
    #[allow(dead_code)]
    fn legacy_contract_mock_path() {
        // Type-level check: ensure `TensorMsg::Contract` still has
        // the documented field layout. We use a closure that, if
        // ever invoked, would build the variant.
        let _build: fn(
            TensorSpec,
            TensorSpec,
            TensorSpec,
            f32,
            f32,
            oneshot::Sender<Result<(), GpuError>>,
        ) -> TensorMsg = |a, b, c, alpha, beta, reply| TensorMsg::Contract {
            a,
            b,
            c,
            alpha,
            beta,
            reply,
        };
    }

    struct MockReq {
        reply: Option<oneshot::Sender<Result<(), GpuError>>>,
    }
    impl TensorDispatch for MockReq {
        fn op_tag(&self) -> &'static str {
            "mock"
        }
        fn dtype_tag(&self) -> &'static str {
            "mock"
        }
        fn dispatch(self: Box<Self>, _ctx: &TensorDispatchCtx) {}
        fn fail_mock(mut self: Box<Self>) {
            if let Some(tx) = self.reply.take() {
                let _ = tx.send(Err(GpuError::Unrecoverable(
                    "TensorActor in mock mode".into(),
                )));
            }
        }
    }
}