Skip to main content

atomr_accel/
gpu_ref.rs

//! `AccelRef<T, B>` — backend-agnostic typed device pointer.
//!
//! Each backend defines its own concrete buffer type
//! (`cudarc::driver::CudaSlice<T>` for CUDA, `hip-sys` slices for
//! ROCm, `MTLBuffer` for Metal, …). This module declares the
//! generation-validated wrapper contract that every backend's concrete
//! `*Ref<T>` type satisfies.
//!
//! Backends that want to expose more structure than this trait offers
//! are encouraged to ship a `pub type AccelRef<T> = MyConcreteRef<T>;`
//! re-export so application code can refer to the concrete type
//! directly when it needs to.
13
14use std::marker::PhantomData;
15
16use crate::backend::AccelBackend;
17use crate::error::AccelError;
18
19/// Trait implemented by every backend's typed-pointer wrapper.
20///
21/// The generation token check is the contract: each backend's
22/// `access()` returns `Err(AccelError::AccelRefStale)` if the
23/// device generation has advanced past the one the ref was minted
24/// against. Code that walks `AccelRef`s never has to know which
25/// backend is underneath.
26pub trait AccelRef<T, B: AccelBackend>: Clone + Send + Sync + 'static {
27    /// Number of `T` elements in the buffer.
28    fn len(&self) -> usize;
29
30    /// Returns true if `len() == 0`.
31    fn is_empty(&self) -> bool {
32        self.len() == 0
33    }
34
35    /// Generation token captured at allocation time. Backends mint
36    /// fresh refs against `device.generation()` and validate the
37    /// match on every `access()`.
38    fn generation(&self) -> u64;
39
40    /// Originating device id. Used by multi-device routing to
41    /// reject cross-device misuse (e.g. AllReduce input mismatch).
42    fn device_id(&self) -> Option<u32>;
43
44    /// Validate the ref is still usable. Returns `Err` if the
45    /// device generation has moved or the device is shutting down.
46    fn check(&self) -> Result<(), AccelError>;
47}
48
/// Marker type that lets portable code name an "abstract `AccelRef<T>`"
/// without committing to a backend.
///
/// Concrete backends usually expose their own typedef instead
/// (e.g. `atomr_accel_cuda::GpuRef<T>`).
///
/// `AnyRef` is zero-sized and never stores a `T` or a `B`:
/// - The definition carries no `B: AccelBackend` bound. Bounds on a struct
///   definition must be repeated at every use site, so they belong on the
///   `impl` blocks that actually need them.
/// - The phantom field is `PhantomData<fn() -> (T, B)>` rather than
///   `PhantomData<(T, B)>`. This keeps the marker covariant in `T` and `B`,
///   makes it `Send + Sync` regardless of them, and stops it from claiming
///   drop-check ownership of a `T` it never holds.
pub struct AnyRef<T, B> {
    _phantom: PhantomData<fn() -> (T, B)>,
}