// oxicuda_backend/lib.rs
1//! Abstract compute backend for GPU-accelerated operations.
2//!
3//! The [`ComputeBackend`] trait defines the interface for GPU computation,
4//! allowing higher-level crates (SciRS2, oxionnx, ToRSh, TrustformeRS)
5//! to use GPU acceleration without coupling to specific GPU APIs.
6//!
7//! # Architecture
8//!
9//! ```text
10//! ┌─────────────────────────────┐
11//! │  SciRS2 / ToRSh / oxionnx  │
12//! │         (consumers)         │
13//! └─────────────┬───────────────┘
14//!               │  dyn ComputeBackend
15//! ┌─────────────▼───────────────┐
16//! │       ComputeBackend        │
17//! │     (trait definition)      │
18//! └─────────────┬───────────────┘
19//!               │
20//! ┌─────────────▼───────────────┐
21//! │  CudaBackend / MetalBackend │
22//! │    (concrete impls)         │
23//! └─────────────────────────────┘
24//! ```
25
26use std::fmt;
27
28// ─── Error types ────────────────────────────────────────────
29
30/// Error type for backend operations.
/// Error type for backend operations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BackendError {
    /// The requested operation is not supported by this backend.
    Unsupported(String),
    /// A GPU/device error occurred.
    DeviceError(String),
    /// Invalid argument to an operation.
    InvalidArgument(String),
    /// Out of device memory.
    OutOfMemory,
    /// Backend not initialized — call [`ComputeBackend::init`] first.
    NotInitialized,
}

impl fmt::Display for BackendError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Fixed-message variants write a static string; variants that
        // carry detail interpolate it after a descriptive prefix.
        match self {
            Self::OutOfMemory => f.write_str("out of device memory"),
            Self::NotInitialized => f.write_str("backend not initialized"),
            Self::Unsupported(what) => write!(f, "unsupported operation: {what}"),
            Self::DeviceError(what) => write!(f, "device error: {what}"),
            Self::InvalidArgument(what) => write!(f, "invalid argument: {what}"),
        }
    }
}

impl std::error::Error for BackendError {}

/// Result type for backend operations.
pub type BackendResult<T> = Result<T, BackendError>;
61
62// ─── Operation enums ────────────────────────────────────────
63
64/// Transpose mode for matrix operations.
/// Transpose mode for matrix operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BackendTranspose {
    /// No transpose — use the matrix as-is.
    NoTrans,
    /// Transpose (swap rows and columns).
    Trans,
    /// Conjugate transpose (Hermitian).
    ConjTrans,
}

impl fmt::Display for BackendTranspose {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Single-letter codes in the style of BLAS transpose flags.
        let code = match self {
            Self::NoTrans => "N",
            Self::Trans => "T",
            Self::ConjTrans => "C",
        };
        f.write_str(code)
    }
}
84
85/// Reduction operation applied along an axis.
/// Reduction operation applied along an axis.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReduceOp {
    /// Summation.
    Sum,
    /// Maximum value.
    Max,
    /// Minimum value.
    Min,
    /// Arithmetic mean.
    Mean,
}

impl fmt::Display for ReduceOp {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Lowercase operation names, matching the other op enums.
        f.write_str(match self {
            Self::Sum => "sum",
            Self::Max => "max",
            Self::Min => "min",
            Self::Mean => "mean",
        })
    }
}
108
109/// Element-wise unary operation.
/// Element-wise unary operation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UnaryOp {
    /// Rectified linear unit: max(0, x).
    Relu,
    /// Sigmoid: 1 / (1 + exp(-x)).
    Sigmoid,
    /// Hyperbolic tangent.
    Tanh,
    /// Exponential.
    Exp,
    /// Natural logarithm.
    Log,
    /// Square root.
    Sqrt,
    /// Absolute value.
    Abs,
    /// Negation.
    Neg,
}

impl fmt::Display for UnaryOp {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Lowercase operation names, matching the other op enums.
        let name = match self {
            Self::Relu => "relu",
            Self::Sigmoid => "sigmoid",
            Self::Tanh => "tanh",
            Self::Exp => "exp",
            Self::Log => "log",
            Self::Sqrt => "sqrt",
            Self::Abs => "abs",
            Self::Neg => "neg",
        };
        f.write_str(name)
    }
}
144
145/// Element-wise binary operation.
/// Element-wise binary operation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BinaryOp {
    /// Addition.
    Add,
    /// Subtraction.
    Sub,
    /// Multiplication.
    Mul,
    /// Division.
    Div,
    /// Element-wise maximum.
    Max,
    /// Element-wise minimum.
    Min,
}

impl fmt::Display for BinaryOp {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Lowercase operation names, matching the other op enums.
        let name = match self {
            Self::Add => "add",
            Self::Sub => "sub",
            Self::Mul => "mul",
            Self::Div => "div",
            Self::Max => "max",
            Self::Min => "min",
        };
        f.write_str(name)
    }
}
174
175// ─── ComputeBackend trait ───────────────────────────────────
176
177/// Abstract compute backend trait.
178///
179/// Implementations provide GPU-accelerated compute operations.
180/// All operations work with opaque device memory pointers (`u64`)
181/// and explicit shape/stride information, making the trait
182/// independent of any particular memory management scheme.
183///
184/// # Object Safety
185///
186/// This trait is object-safe and can be used as `Box<dyn ComputeBackend>`
187/// or `&dyn ComputeBackend` for dynamic dispatch.
188///
189/// # Lifecycle
190///
191/// 1. Create the backend (`CudaBackend::new()`).
192/// 2. Call [`init`](ComputeBackend::init) to select a device and create a context.
193/// 3. Allocate memory with [`alloc`](ComputeBackend::alloc).
194/// 4. Transfer data with [`copy_htod`](ComputeBackend::copy_htod).
195/// 5. Run compute operations ([`gemm`](ComputeBackend::gemm), [`conv2d_forward`](ComputeBackend::conv2d_forward), etc.).
196/// 6. Read results with [`copy_dtoh`](ComputeBackend::copy_dtoh).
197/// 7. Free memory with [`free`](ComputeBackend::free).
198pub trait ComputeBackend: Send + Sync + fmt::Debug {
199    /// Backend name (e.g., `"cuda"`, `"rocm"`, `"metal"`).
200    fn name(&self) -> &str;
201
202    /// Initialize the backend (select device, create context).
203    ///
204    /// Must be called before any other operation. Calling `init` on an
205    /// already-initialized backend is a no-op.
206    fn init(&mut self) -> BackendResult<()>;
207
208    /// Returns `true` if the backend is ready for operations.
209    fn is_initialized(&self) -> bool;
210
211    /// General matrix multiply: `C = alpha * op(A) * op(B) + beta * C`.
212    ///
213    /// # Arguments
214    ///
215    /// * `trans_a`, `trans_b` — transpose modes for A and B.
216    /// * `m`, `n`, `k` — matrix dimensions (C is m×n, A is m×k, B is k×n after transpose).
217    /// * `alpha`, `beta` — scaling factors.
218    /// * `a_ptr`, `b_ptr`, `c_ptr` — device pointers to column-major f64 matrices.
219    /// * `lda`, `ldb`, `ldc` — leading dimensions.
220    #[allow(clippy::too_many_arguments)]
221    fn gemm(
222        &self,
223        trans_a: BackendTranspose,
224        trans_b: BackendTranspose,
225        m: usize,
226        n: usize,
227        k: usize,
228        alpha: f64,
229        a_ptr: u64,
230        lda: usize,
231        b_ptr: u64,
232        ldb: usize,
233        beta: f64,
234        c_ptr: u64,
235        ldc: usize,
236    ) -> BackendResult<()>;
237
238    /// 2D convolution forward pass.
239    ///
240    /// # Arguments
241    ///
242    /// * `input_ptr` — device pointer to input tensor (NCHW layout).
243    /// * `input_shape` — `[N, C, H, W]`.
244    /// * `filter_ptr` — device pointer to filter tensor.
245    /// * `filter_shape` — `[K, C, Fh, Fw]`.
246    /// * `output_ptr` — device pointer to output tensor.
247    /// * `output_shape` — `[N, K, Oh, Ow]`.
248    /// * `stride` — `[sh, sw]`.
249    /// * `padding` — `[ph, pw]`.
250    #[allow(clippy::too_many_arguments)]
251    fn conv2d_forward(
252        &self,
253        input_ptr: u64,
254        input_shape: &[usize],
255        filter_ptr: u64,
256        filter_shape: &[usize],
257        output_ptr: u64,
258        output_shape: &[usize],
259        stride: &[usize],
260        padding: &[usize],
261    ) -> BackendResult<()>;
262
263    /// Scaled dot-product attention.
264    ///
265    /// Computes `softmax(Q * K^T / scale) * V` with optional causal masking.
266    ///
267    /// # Arguments
268    ///
269    /// * `q_ptr`, `k_ptr`, `v_ptr` — device pointers to query, key, value tensors.
270    /// * `o_ptr` — device pointer to output tensor.
271    /// * `batch`, `heads` — batch size and number of attention heads.
272    /// * `seq_q`, `seq_kv` — query and key/value sequence lengths.
273    /// * `head_dim` — dimension of each attention head.
274    /// * `scale` — attention scale factor (typically `1 / sqrt(head_dim)`).
275    /// * `causal` — if `true`, apply causal (lower-triangular) mask.
276    #[allow(clippy::too_many_arguments)]
277    fn attention(
278        &self,
279        q_ptr: u64,
280        k_ptr: u64,
281        v_ptr: u64,
282        o_ptr: u64,
283        batch: usize,
284        heads: usize,
285        seq_q: usize,
286        seq_kv: usize,
287        head_dim: usize,
288        scale: f64,
289        causal: bool,
290    ) -> BackendResult<()>;
291
292    /// Reduction along an axis.
293    ///
294    /// Reduces `input` along `axis` using the specified `op` and writes to `output`.
295    fn reduce(
296        &self,
297        op: ReduceOp,
298        input_ptr: u64,
299        output_ptr: u64,
300        shape: &[usize],
301        axis: usize,
302    ) -> BackendResult<()>;
303
304    /// Element-wise unary operation.
305    ///
306    /// Applies `op` to each of the `n` elements at `input_ptr` and writes to `output_ptr`.
307    fn unary(&self, op: UnaryOp, input_ptr: u64, output_ptr: u64, n: usize) -> BackendResult<()>;
308
309    /// Element-wise binary operation.
310    ///
311    /// Applies `op` element-wise: `output[i] = op(a[i], b[i])` for `n` elements.
312    fn binary(
313        &self,
314        op: BinaryOp,
315        a_ptr: u64,
316        b_ptr: u64,
317        output_ptr: u64,
318        n: usize,
319    ) -> BackendResult<()>;
320
321    /// Strided batched GEMM: for each batch `b` in `0..batch_count`,
322    /// compute `C_b = alpha * op(A_b) * op(B_b) + beta * C_b`
323    /// where `A_b` starts at `a_ptr + b * stride_a * 4` bytes (f32 elements), etc.
324    ///
325    /// # Arguments
326    ///
327    /// * `trans_a`, `trans_b` — transpose modes for A and B.
328    /// * `m`, `n`, `k` — matrix dimensions (C is m×n).
329    /// * `alpha`, `beta` — scaling factors.
330    /// * `a_ptr`, `b_ptr`, `c_ptr` — device pointers to the first matrix in each batch.
331    /// * `lda`, `ldb`, `ldc` — leading dimensions.
332    /// * `stride_a`, `stride_b`, `stride_c` — element strides between consecutive matrices.
333    /// * `batch_count` — number of GEMM operations in the batch.
334    ///
335    /// The default implementation dispatches `batch_count` individual
336    /// [`gemm`](Self::gemm) calls with pointer offsets.
337    #[allow(clippy::too_many_arguments)]
338    fn batched_gemm(
339        &self,
340        trans_a: BackendTranspose,
341        trans_b: BackendTranspose,
342        m: usize,
343        n: usize,
344        k: usize,
345        alpha: f64,
346        a_ptr: u64,
347        lda: usize,
348        stride_a: usize,
349        b_ptr: u64,
350        ldb: usize,
351        stride_b: usize,
352        beta: f64,
353        c_ptr: u64,
354        ldc: usize,
355        stride_c: usize,
356        batch_count: usize,
357    ) -> BackendResult<()> {
358        // Default: loop over individual gemm calls with byte-offset pointers.
359        // Backends should override with a single batched kernel for efficiency.
360        let elem_bytes: u64 = 4; // f32
361        for b in 0..batch_count {
362            let b64 = b as u64;
363            self.gemm(
364                trans_a,
365                trans_b,
366                m,
367                n,
368                k,
369                alpha,
370                a_ptr + b64 * stride_a as u64 * elem_bytes,
371                lda,
372                b_ptr + b64 * stride_b as u64 * elem_bytes,
373                ldb,
374                beta,
375                c_ptr + b64 * stride_c as u64 * elem_bytes,
376                ldc,
377            )?;
378        }
379        Ok(())
380    }
381
382    /// Synchronize all pending operations on this backend.
383    ///
384    /// Blocks the host until all previously submitted GPU work completes.
385    fn synchronize(&self) -> BackendResult<()>;
386
387    /// Allocate device memory.
388    ///
389    /// Returns an opaque device pointer. The caller is responsible for
390    /// eventually calling [`free`](ComputeBackend::free).
391    fn alloc(&self, bytes: usize) -> BackendResult<u64>;
392
393    /// Free device memory previously allocated with [`alloc`](ComputeBackend::alloc).
394    fn free(&self, ptr: u64) -> BackendResult<()>;
395
396    /// Copy data from host memory to device memory.
397    ///
398    /// * `dst` — device pointer (destination).
399    /// * `src` — host byte slice (source).
400    fn copy_htod(&self, dst: u64, src: &[u8]) -> BackendResult<()>;
401
402    /// Copy data from device memory to host memory.
403    ///
404    /// * `dst` — host byte slice (destination).
405    /// * `src` — device pointer (source).
406    fn copy_dtoh(&self, dst: &mut [u8], src: u64) -> BackendResult<()>;
407}
408
409// ─── Tests ──────────────────────────────────────────────────
410
#[cfg(test)]
mod tests {
    //! Unit tests for the error type, the operation enums, and the
    //! default `batched_gemm` loop provided by `ComputeBackend`.

    use super::*;

    // Each error variant must render exactly its documented message.
    #[test]
    fn backend_error_display() {
        assert_eq!(
            BackendError::Unsupported("foo".into()).to_string(),
            "unsupported operation: foo"
        );
        assert_eq!(
            BackendError::DeviceError("bar".into()).to_string(),
            "device error: bar"
        );
        assert_eq!(
            BackendError::InvalidArgument("baz".into()).to_string(),
            "invalid argument: baz"
        );
        assert_eq!(
            BackendError::OutOfMemory.to_string(),
            "out of device memory"
        );
        assert_eq!(
            BackendError::NotInitialized.to_string(),
            "backend not initialized"
        );
    }

    // BackendError implements std::error::Error, so it can be boxed and
    // used through the trait object's Display.
    #[test]
    fn backend_error_is_std_error() {
        let err: Box<dyn std::error::Error> = Box::new(BackendError::DeviceError("test".into()));
        assert!(err.to_string().contains("test"));
    }

    // Display codes follow the single-letter N/T/C convention.
    #[test]
    fn backend_transpose_display_and_values() {
        assert_eq!(BackendTranspose::NoTrans.to_string(), "N");
        assert_eq!(BackendTranspose::Trans.to_string(), "T");
        assert_eq!(BackendTranspose::ConjTrans.to_string(), "C");

        // Equality
        assert_eq!(BackendTranspose::NoTrans, BackendTranspose::NoTrans);
        assert_ne!(BackendTranspose::NoTrans, BackendTranspose::Trans);
    }

    // Exhaustively pairs every ReduceOp variant with its display name.
    #[test]
    fn reduce_op_display_and_coverage() {
        let ops = [ReduceOp::Sum, ReduceOp::Max, ReduceOp::Min, ReduceOp::Mean];
        let names = ["sum", "max", "min", "mean"];
        for (op, name) in ops.iter().zip(names.iter()) {
            assert_eq!(op.to_string(), *name);
        }
    }

    // Exhaustively pairs every UnaryOp variant with its display name.
    #[test]
    fn unary_op_display_and_coverage() {
        let ops = [
            UnaryOp::Relu,
            UnaryOp::Sigmoid,
            UnaryOp::Tanh,
            UnaryOp::Exp,
            UnaryOp::Log,
            UnaryOp::Sqrt,
            UnaryOp::Abs,
            UnaryOp::Neg,
        ];
        let names = [
            "relu", "sigmoid", "tanh", "exp", "log", "sqrt", "abs", "neg",
        ];
        for (op, name) in ops.iter().zip(names.iter()) {
            assert_eq!(op.to_string(), *name);
        }
    }

    // Exhaustively pairs every BinaryOp variant with its display name.
    #[test]
    fn binary_op_display_and_coverage() {
        let ops = [
            BinaryOp::Add,
            BinaryOp::Sub,
            BinaryOp::Mul,
            BinaryOp::Div,
            BinaryOp::Max,
            BinaryOp::Min,
        ];
        let names = ["add", "sub", "mul", "div", "max", "min"];
        for (op, name) in ops.iter().zip(names.iter()) {
            assert_eq!(op.to_string(), *name);
        }
    }

    // ── Mock backend for testing default batched_gemm ──

    use std::sync::atomic::{AtomicUsize, Ordering};

    // Minimal no-op ComputeBackend whose only observable state is how
    // many times gemm was invoked; lets the tests below verify the
    // dispatch count of the default batched_gemm implementation.
    #[derive(Debug)]
    struct MockBackend {
        gemm_call_count: AtomicUsize,
    }

    impl MockBackend {
        fn new() -> Self {
            Self {
                gemm_call_count: AtomicUsize::new(0),
            }
        }
    }

    impl ComputeBackend for MockBackend {
        fn name(&self) -> &str {
            "mock"
        }
        fn init(&mut self) -> BackendResult<()> {
            Ok(())
        }
        fn is_initialized(&self) -> bool {
            true
        }
        // Records the call; arguments are ignored (atomic counter keeps
        // the mock usable through &self).
        fn gemm(
            &self,
            _trans_a: BackendTranspose,
            _trans_b: BackendTranspose,
            _m: usize,
            _n: usize,
            _k: usize,
            _alpha: f64,
            _a_ptr: u64,
            _lda: usize,
            _b_ptr: u64,
            _ldb: usize,
            _beta: f64,
            _c_ptr: u64,
            _ldc: usize,
        ) -> BackendResult<()> {
            self.gemm_call_count.fetch_add(1, Ordering::Relaxed);
            Ok(())
        }
        fn conv2d_forward(
            &self,
            _: u64,
            _: &[usize],
            _: u64,
            _: &[usize],
            _: u64,
            _: &[usize],
            _: &[usize],
            _: &[usize],
        ) -> BackendResult<()> {
            Ok(())
        }
        fn attention(
            &self,
            _: u64,
            _: u64,
            _: u64,
            _: u64,
            _: usize,
            _: usize,
            _: usize,
            _: usize,
            _: usize,
            _: f64,
            _: bool,
        ) -> BackendResult<()> {
            Ok(())
        }
        fn reduce(&self, _: ReduceOp, _: u64, _: u64, _: &[usize], _: usize) -> BackendResult<()> {
            Ok(())
        }
        fn unary(&self, _: UnaryOp, _: u64, _: u64, _: usize) -> BackendResult<()> {
            Ok(())
        }
        fn binary(&self, _: BinaryOp, _: u64, _: u64, _: u64, _: usize) -> BackendResult<()> {
            Ok(())
        }
        fn synchronize(&self) -> BackendResult<()> {
            Ok(())
        }
        fn alloc(&self, _: usize) -> BackendResult<u64> {
            Ok(0)
        }
        fn free(&self, _: u64) -> BackendResult<()> {
            Ok(())
        }
        fn copy_htod(&self, _: u64, _: &[u8]) -> BackendResult<()> {
            Ok(())
        }
        fn copy_dtoh(&self, _: &mut [u8], _: u64) -> BackendResult<()> {
            Ok(())
        }
    }

    // batch_count == 0 must succeed without any gemm dispatch.
    #[test]
    fn batched_gemm_zero_batch_is_noop() {
        let backend = MockBackend::new();
        let result = backend.batched_gemm(
            BackendTranspose::NoTrans,
            BackendTranspose::NoTrans,
            4,
            4,
            4,
            1.0,
            0,
            4,
            16,
            0,
            4,
            16,
            0.0,
            0,
            4,
            16,
            0, // batch_count = 0
        );
        assert!(result.is_ok());
        assert_eq!(backend.gemm_call_count.load(Ordering::Relaxed), 0);
    }

    // The default implementation dispatches exactly batch_count gemms.
    #[test]
    fn batched_gemm_default_calls_gemm_n_times() {
        let backend = MockBackend::new();
        let batch_count = 7;
        let result = backend.batched_gemm(
            BackendTranspose::NoTrans,
            BackendTranspose::Trans,
            8,
            8,
            8,
            1.0,
            1000,
            8,
            64,
            2000,
            8,
            64,
            0.0,
            3000,
            8,
            64,
            batch_count,
        );
        assert!(result.is_ok());
        assert_eq!(backend.gemm_call_count.load(Ordering::Relaxed), batch_count);
    }

    // Degenerate single-batch case: exactly one gemm dispatch.
    #[test]
    fn batched_gemm_single_batch() {
        let backend = MockBackend::new();
        let result = backend.batched_gemm(
            BackendTranspose::NoTrans,
            BackendTranspose::NoTrans,
            16,
            16,
            16,
            1.0,
            0,
            16,
            256,
            0,
            16,
            256,
            1.0,
            0,
            16,
            256,
            1,
        );
        assert!(result.is_ok());
        assert_eq!(backend.gemm_call_count.load(Ordering::Relaxed), 1);
    }

    // Smoke-tests the derived Copy/Clone/Hash/Eq impls on the op enums.
    #[test]
    fn enum_clone_and_hash() {
        use std::collections::HashSet;

        let mut set = HashSet::new();
        set.insert(ReduceOp::Sum);
        set.insert(ReduceOp::Max);
        assert!(set.contains(&ReduceOp::Sum));
        assert!(!set.contains(&ReduceOp::Min));

        // Clone
        let op = UnaryOp::Relu;
        let cloned = op;
        assert_eq!(op, cloned);

        let bop = BinaryOp::Add;
        let bcloned = bop;
        assert_eq!(bop, bcloned);

        let trans = BackendTranspose::ConjTrans;
        let tcloned = trans;
        assert_eq!(trans, tcloned);
    }
}