// oxicuda_backend/lib.rs

//! Abstract compute backend for GPU-accelerated operations.
//!
//! The [`ComputeBackend`] trait defines the interface for GPU computation,
//! allowing higher-level crates (SciRS2, oxionnx, ToRSh, TrustformeRS)
//! to use GPU acceleration without coupling to specific GPU APIs.
//!
//! # Architecture
//!
//! ```text
//! ┌─────────────────────────────┐
//! │  SciRS2 / ToRSh / oxionnx  │
//! │         (consumers)         │
//! └─────────────┬───────────────┘
//!               │  dyn ComputeBackend
//! ┌─────────────▼───────────────┐
//! │       ComputeBackend        │
//! │     (trait definition)      │
//! └─────────────┬───────────────┘
//!               │
//! ┌─────────────▼───────────────┐
//! │  CudaBackend / MetalBackend │
//! │    (concrete impls)         │
//! └─────────────────────────────┘
//! ```

use std::fmt;

// ─── Error types ────────────────────────────────────────────

/// Error type for backend operations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BackendError {
    /// The requested operation is not supported by this backend.
    Unsupported(String),
    /// A GPU/device error occurred.
    DeviceError(String),
    /// Invalid argument to an operation.
    InvalidArgument(String),
    /// Out of device memory.
    OutOfMemory,
    /// Backend not initialized — call [`ComputeBackend::init`] first.
    NotInitialized,
}

impl fmt::Display for BackendError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Unsupported(msg) => write!(f, "unsupported operation: {msg}"),
            Self::DeviceError(msg) => write!(f, "device error: {msg}"),
            Self::InvalidArgument(msg) => write!(f, "invalid argument: {msg}"),
            Self::OutOfMemory => write!(f, "out of device memory"),
            Self::NotInitialized => write!(f, "backend not initialized"),
        }
    }
}

impl std::error::Error for BackendError {}

/// Result type for backend operations.
pub type BackendResult<T> = Result<T, BackendError>;

62// ─── Operation enums ────────────────────────────────────────
63
64/// Transpose mode for matrix operations.
65#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
66pub enum BackendTranspose {
67    /// No transpose — use the matrix as-is.
68    NoTrans,
69    /// Transpose (swap rows and columns).
70    Trans,
71    /// Conjugate transpose (Hermitian).
72    ConjTrans,
73}
74
75impl fmt::Display for BackendTranspose {
76    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
77        match self {
78            Self::NoTrans => write!(f, "N"),
79            Self::Trans => write!(f, "T"),
80            Self::ConjTrans => write!(f, "C"),
81        }
82    }
83}
84
/// Reduction operation applied along an axis.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReduceOp {
    /// Summation.
    Sum,
    /// Maximum value.
    Max,
    /// Minimum value.
    Min,
    /// Arithmetic mean.
    Mean,
}

impl fmt::Display for ReduceOp {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Sum => write!(f, "sum"),
            Self::Max => write!(f, "max"),
            Self::Min => write!(f, "min"),
            Self::Mean => write!(f, "mean"),
        }
    }
}

/// Element-wise unary operation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UnaryOp {
    /// Rectified linear unit: max(0, x).
    Relu,
    /// Sigmoid: 1 / (1 + exp(-x)).
    Sigmoid,
    /// Hyperbolic tangent.
    Tanh,
    /// Exponential.
    Exp,
    /// Natural logarithm.
    Log,
    /// Square root.
    Sqrt,
    /// Absolute value.
    Abs,
    /// Negation.
    Neg,
}

impl fmt::Display for UnaryOp {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Relu => write!(f, "relu"),
            Self::Sigmoid => write!(f, "sigmoid"),
            Self::Tanh => write!(f, "tanh"),
            Self::Exp => write!(f, "exp"),
            Self::Log => write!(f, "log"),
            Self::Sqrt => write!(f, "sqrt"),
            Self::Abs => write!(f, "abs"),
            Self::Neg => write!(f, "neg"),
        }
    }
}

/// Element-wise binary operation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BinaryOp {
    /// Addition.
    Add,
    /// Subtraction.
    Sub,
    /// Multiplication.
    Mul,
    /// Division.
    Div,
    /// Element-wise maximum.
    Max,
    /// Element-wise minimum.
    Min,
}

impl fmt::Display for BinaryOp {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Add => write!(f, "add"),
            Self::Sub => write!(f, "sub"),
            Self::Mul => write!(f, "mul"),
            Self::Div => write!(f, "div"),
            Self::Max => write!(f, "max"),
            Self::Min => write!(f, "min"),
        }
    }
}

175// ─── ComputeBackend trait ───────────────────────────────────
176
177/// Abstract compute backend trait.
178///
179/// Implementations provide GPU-accelerated compute operations.
180/// All operations work with opaque device memory pointers (`u64`)
181/// and explicit shape/stride information, making the trait
182/// independent of any particular memory management scheme.
183///
184/// # Object Safety
185///
186/// This trait is object-safe and can be used as `Box<dyn ComputeBackend>`
187/// or `&dyn ComputeBackend` for dynamic dispatch.
188///
189/// # Lifecycle
190///
191/// 1. Create the backend (`CudaBackend::new()`).
192/// 2. Call [`init`](ComputeBackend::init) to select a device and create a context.
193/// 3. Allocate memory with [`alloc`](ComputeBackend::alloc).
194/// 4. Transfer data with [`copy_htod`](ComputeBackend::copy_htod).
195/// 5. Run compute operations ([`gemm`](ComputeBackend::gemm), [`conv2d_forward`](ComputeBackend::conv2d_forward), etc.).
196/// 6. Read results with [`copy_dtoh`](ComputeBackend::copy_dtoh).
197/// 7. Free memory with [`free`](ComputeBackend::free).
198pub trait ComputeBackend: Send + Sync + fmt::Debug {
199    /// Backend name (e.g., `"cuda"`, `"rocm"`, `"metal"`).
200    fn name(&self) -> &str;
201
202    /// Initialize the backend (select device, create context).
203    ///
204    /// Must be called before any other operation. Calling `init` on an
205    /// already-initialized backend is a no-op.
206    fn init(&mut self) -> BackendResult<()>;
207
208    /// Returns `true` if the backend is ready for operations.
209    fn is_initialized(&self) -> bool;
210
211    /// General matrix multiply: `C = alpha * op(A) * op(B) + beta * C`.
212    ///
213    /// # Arguments
214    ///
215    /// * `trans_a`, `trans_b` — transpose modes for A and B.
216    /// * `m`, `n`, `k` — matrix dimensions (C is m×n, A is m×k, B is k×n after transpose).
217    /// * `alpha`, `beta` — scaling factors.
218    /// * `a_ptr`, `b_ptr`, `c_ptr` — device pointers to column-major f64 matrices.
219    /// * `lda`, `ldb`, `ldc` — leading dimensions.
220    #[allow(clippy::too_many_arguments)]
221    fn gemm(
222        &self,
223        trans_a: BackendTranspose,
224        trans_b: BackendTranspose,
225        m: usize,
226        n: usize,
227        k: usize,
228        alpha: f64,
229        a_ptr: u64,
230        lda: usize,
231        b_ptr: u64,
232        ldb: usize,
233        beta: f64,
234        c_ptr: u64,
235        ldc: usize,
236    ) -> BackendResult<()>;
237
238    /// 2D convolution forward pass.
239    ///
240    /// # Arguments
241    ///
242    /// * `input_ptr` — device pointer to input tensor (NCHW layout).
243    /// * `input_shape` — `[N, C, H, W]`.
244    /// * `filter_ptr` — device pointer to filter tensor.
245    /// * `filter_shape` — `[K, C, Fh, Fw]`.
246    /// * `output_ptr` — device pointer to output tensor.
247    /// * `output_shape` — `[N, K, Oh, Ow]`.
248    /// * `stride` — `[sh, sw]`.
249    /// * `padding` — `[ph, pw]`.
250    #[allow(clippy::too_many_arguments)]
251    fn conv2d_forward(
252        &self,
253        input_ptr: u64,
254        input_shape: &[usize],
255        filter_ptr: u64,
256        filter_shape: &[usize],
257        output_ptr: u64,
258        output_shape: &[usize],
259        stride: &[usize],
260        padding: &[usize],
261    ) -> BackendResult<()>;
262
263    /// Scaled dot-product attention.
264    ///
265    /// Computes `softmax(Q * K^T / scale) * V` with optional causal masking.
266    ///
267    /// # Arguments
268    ///
269    /// * `q_ptr`, `k_ptr`, `v_ptr` — device pointers to query, key, value tensors.
270    /// * `o_ptr` — device pointer to output tensor.
271    /// * `batch`, `heads` — batch size and number of attention heads.
272    /// * `seq_q`, `seq_kv` — query and key/value sequence lengths.
273    /// * `head_dim` — dimension of each attention head.
274    /// * `scale` — attention scale factor (typically `1 / sqrt(head_dim)`).
275    /// * `causal` — if `true`, apply causal (lower-triangular) mask.
276    #[allow(clippy::too_many_arguments)]
277    fn attention(
278        &self,
279        q_ptr: u64,
280        k_ptr: u64,
281        v_ptr: u64,
282        o_ptr: u64,
283        batch: usize,
284        heads: usize,
285        seq_q: usize,
286        seq_kv: usize,
287        head_dim: usize,
288        scale: f64,
289        causal: bool,
290    ) -> BackendResult<()>;
291
292    /// Reduction along an axis.
293    ///
294    /// Reduces `input` along `axis` using the specified `op` and writes to `output`.
295    fn reduce(
296        &self,
297        op: ReduceOp,
298        input_ptr: u64,
299        output_ptr: u64,
300        shape: &[usize],
301        axis: usize,
302    ) -> BackendResult<()>;
303
304    /// Element-wise unary operation.
305    ///
306    /// Applies `op` to each of the `n` elements at `input_ptr` and writes to `output_ptr`.
307    fn unary(&self, op: UnaryOp, input_ptr: u64, output_ptr: u64, n: usize) -> BackendResult<()>;
308
309    /// Element-wise binary operation.
310    ///
311    /// Applies `op` element-wise: `output[i] = op(a[i], b[i])` for `n` elements.
312    fn binary(
313        &self,
314        op: BinaryOp,
315        a_ptr: u64,
316        b_ptr: u64,
317        output_ptr: u64,
318        n: usize,
319    ) -> BackendResult<()>;
320
321    /// Synchronize all pending operations on this backend.
322    ///
323    /// Blocks the host until all previously submitted GPU work completes.
324    fn synchronize(&self) -> BackendResult<()>;
325
326    /// Allocate device memory.
327    ///
328    /// Returns an opaque device pointer. The caller is responsible for
329    /// eventually calling [`free`](ComputeBackend::free).
330    fn alloc(&self, bytes: usize) -> BackendResult<u64>;
331
332    /// Free device memory previously allocated with [`alloc`](ComputeBackend::alloc).
333    fn free(&self, ptr: u64) -> BackendResult<()>;
334
335    /// Copy data from host memory to device memory.
336    ///
337    /// * `dst` — device pointer (destination).
338    /// * `src` — host byte slice (source).
339    fn copy_htod(&self, dst: u64, src: &[u8]) -> BackendResult<()>;
340
341    /// Copy data from device memory to host memory.
342    ///
343    /// * `dst` — host byte slice (destination).
344    /// * `src` — device pointer (source).
345    fn copy_dtoh(&self, dst: &mut [u8], src: u64) -> BackendResult<()>;
346}
347
// ─── Tests ──────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn backend_error_display() {
        assert_eq!(
            BackendError::Unsupported("foo".into()).to_string(),
            "unsupported operation: foo"
        );
        assert_eq!(
            BackendError::DeviceError("bar".into()).to_string(),
            "device error: bar"
        );
        assert_eq!(
            BackendError::InvalidArgument("baz".into()).to_string(),
            "invalid argument: baz"
        );
        assert_eq!(
            BackendError::OutOfMemory.to_string(),
            "out of device memory"
        );
        assert_eq!(
            BackendError::NotInitialized.to_string(),
            "backend not initialized"
        );
    }

    #[test]
    fn backend_error_is_std_error() {
        let err: Box<dyn std::error::Error> = Box::new(BackendError::DeviceError("test".into()));
        assert!(err.to_string().contains("test"));
    }

    #[test]
    fn backend_transpose_display_and_values() {
        assert_eq!(BackendTranspose::NoTrans.to_string(), "N");
        assert_eq!(BackendTranspose::Trans.to_string(), "T");
        assert_eq!(BackendTranspose::ConjTrans.to_string(), "C");

        // Equality
        assert_eq!(BackendTranspose::NoTrans, BackendTranspose::NoTrans);
        assert_ne!(BackendTranspose::NoTrans, BackendTranspose::Trans);
    }

    #[test]
    fn reduce_op_display_and_coverage() {
        let ops = [ReduceOp::Sum, ReduceOp::Max, ReduceOp::Min, ReduceOp::Mean];
        let names = ["sum", "max", "min", "mean"];
        for (op, name) in ops.iter().zip(names.iter()) {
            assert_eq!(op.to_string(), *name);
        }
    }

    #[test]
    fn unary_op_display_and_coverage() {
        let ops = [
            UnaryOp::Relu,
            UnaryOp::Sigmoid,
            UnaryOp::Tanh,
            UnaryOp::Exp,
            UnaryOp::Log,
            UnaryOp::Sqrt,
            UnaryOp::Abs,
            UnaryOp::Neg,
        ];
        let names = [
            "relu", "sigmoid", "tanh", "exp", "log", "sqrt", "abs", "neg",
        ];
        for (op, name) in ops.iter().zip(names.iter()) {
            assert_eq!(op.to_string(), *name);
        }
    }

    #[test]
    fn binary_op_display_and_coverage() {
        let ops = [
            BinaryOp::Add,
            BinaryOp::Sub,
            BinaryOp::Mul,
            BinaryOp::Div,
            BinaryOp::Max,
            BinaryOp::Min,
        ];
        let names = ["add", "sub", "mul", "div", "max", "min"];
        for (op, name) in ops.iter().zip(names.iter()) {
            assert_eq!(op.to_string(), *name);
        }
    }

    #[test]
    fn enum_clone_and_hash() {
        use std::collections::HashSet;

        let mut set = HashSet::new();
        set.insert(ReduceOp::Sum);
        set.insert(ReduceOp::Max);
        assert!(set.contains(&ReduceOp::Sum));
        assert!(!set.contains(&ReduceOp::Min));

        // Clone
        let op = UnaryOp::Relu;
        let cloned = op;
        assert_eq!(op, cloned);

        let bop = BinaryOp::Add;
        let bcloned = bop;
        assert_eq!(bop, bcloned);

        let trans = BackendTranspose::ConjTrans;
        let tcloned = trans;
        assert_eq!(trans, tcloned);
    }
}
463}