Trait ComputeBackend 

pub trait ComputeBackend: Send + Sync + Debug {
    // Required methods
    fn name(&self) -> &str;
    fn init(&mut self) -> BackendResult<()>;
    fn is_initialized(&self) -> bool;
    fn gemm(&self, trans_a: BackendTranspose, trans_b: BackendTranspose, m: usize, n: usize, k: usize, alpha: f64, a_ptr: u64, lda: usize, b_ptr: u64, ldb: usize, beta: f64, c_ptr: u64, ldc: usize) -> BackendResult<()>;
    fn conv2d_forward(&self, input_ptr: u64, input_shape: &[usize], filter_ptr: u64, filter_shape: &[usize], output_ptr: u64, output_shape: &[usize], stride: &[usize], padding: &[usize]) -> BackendResult<()>;
    fn attention(&self, q_ptr: u64, k_ptr: u64, v_ptr: u64, o_ptr: u64, batch: usize, heads: usize, seq_q: usize, seq_kv: usize, head_dim: usize, scale: f64, causal: bool) -> BackendResult<()>;
    fn reduce(&self, op: ReduceOp, input_ptr: u64, output_ptr: u64, shape: &[usize], axis: usize) -> BackendResult<()>;
    fn unary(&self, op: UnaryOp, input_ptr: u64, output_ptr: u64, n: usize) -> BackendResult<()>;
    fn binary(&self, op: BinaryOp, a_ptr: u64, b_ptr: u64, output_ptr: u64, n: usize) -> BackendResult<()>;
    fn synchronize(&self) -> BackendResult<()>;
    fn alloc(&self, bytes: usize) -> BackendResult<u64>;
    fn free(&self, ptr: u64) -> BackendResult<()>;
    fn copy_htod(&self, dst: u64, src: &[u8]) -> BackendResult<()>;
    fn copy_dtoh(&self, dst: &mut [u8], src: u64) -> BackendResult<()>;

    // Provided method
    fn batched_gemm(&self, trans_a: BackendTranspose, trans_b: BackendTranspose, m: usize, n: usize, k: usize, alpha: f64, a_ptr: u64, lda: usize, stride_a: usize, b_ptr: u64, ldb: usize, stride_b: usize, beta: f64, c_ptr: u64, ldc: usize, stride_c: usize, batch_count: usize) -> BackendResult<()> { ... }
}

Abstract compute backend trait.

Implementations provide GPU-accelerated compute operations. All operations work with opaque device memory pointers (u64) and explicit shape/stride information, making the trait independent of any particular memory management scheme.

§Object Safety

This trait is object-safe and can be used as Box<dyn ComputeBackend> or &dyn ComputeBackend for dynamic dispatch.

§Lifecycle

  1. Create the backend (CudaBackend::new()).
  2. Call init to select a device and create a context.
  3. Allocate memory with alloc.
  4. Transfer data with copy_htod.
  5. Run compute operations (gemm, conv2d_forward, etc.).
  6. Read results with copy_dtoh.
  7. Free memory with free.
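A minimal end-to-end sketch of this lifecycle. The crate path, the exact return type of CudaBackend::new(), and the BackendTranspose::NoTrans variant name are assumptions; only the trait methods themselves come from this page.

use backend_crate::{BackendResult, BackendTranspose, ComputeBackend, CudaBackend}; // crate path is hypothetical

fn run() -> BackendResult<()> {
    // 1-2. Create and initialize the backend.
    let mut backend = CudaBackend::new();
    backend.init()?;
    assert!(backend.is_initialized());

    // 3. Allocate device memory for three 2x2 f64 matrices.
    let bytes = 4 * std::mem::size_of::<f64>();
    let a = backend.alloc(bytes)?;
    let b = backend.alloc(bytes)?;
    let c = backend.alloc(bytes)?;

    // 4. Upload column-major data as raw bytes.
    let host_a = [1.0f64, 2.0, 3.0, 4.0];
    let host_b = [5.0f64, 6.0, 7.0, 8.0];
    backend.copy_htod(a, &to_bytes(&host_a))?;
    backend.copy_htod(b, &to_bytes(&host_b))?;
    backend.copy_htod(c, &vec![0u8; bytes])?;

    // 5. C = 1.0 * A * B + 0.0 * C with m = n = k = 2 and leading dimension 2.
    backend.gemm(BackendTranspose::NoTrans, BackendTranspose::NoTrans,
                 2, 2, 2, 1.0, a, 2, b, 2, 0.0, c, 2)?;
    backend.synchronize()?;

    // 6. Read the result back to the host.
    let mut out = vec![0u8; bytes];
    backend.copy_dtoh(&mut out, c)?;

    // 7. Free device memory.
    backend.free(a)?;
    backend.free(b)?;
    backend.free(c)?;
    Ok(())
}

fn to_bytes(xs: &[f64]) -> Vec<u8> {
    xs.iter().flat_map(|x| x.to_ne_bytes()).collect()
}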

Required Methods§

fn name(&self) -> &str

Backend name (e.g., "cuda", "rocm", "metal").

fn init(&mut self) -> BackendResult<()>

Initialize the backend (select device, create context).

Must be called before any other operation. Calling init on an already-initialized backend is a no-op.

fn is_initialized(&self) -> bool

Returns true if the backend is ready for operations.

fn gemm(
    &self,
    trans_a: BackendTranspose,
    trans_b: BackendTranspose,
    m: usize,
    n: usize,
    k: usize,
    alpha: f64,
    a_ptr: u64,
    lda: usize,
    b_ptr: u64,
    ldb: usize,
    beta: f64,
    c_ptr: u64,
    ldc: usize,
) -> BackendResult<()>

General matrix multiply: C = alpha * op(A) * op(B) + beta * C.

§Arguments
  • trans_a, trans_b — transpose modes for A and B.
  • m, n, k — matrix dimensions (C is m×n, A is m×k, B is k×n after transpose).
  • alpha, beta — scaling factors.
  • a_ptr, b_ptr, c_ptr — device pointers to column-major f64 matrices.
  • lda, ldb, ldc — leading dimensions.
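As a hedged call sketch: with no transposes, a 2x3 A times a 3x2 B (column-major, so lda = 2 and ldb = 3) yields a 2x2 C with ldc = 2. The device pointers a, b, c and the BackendTranspose::NoTrans variant name are assumptions.

// C(2x2) = 1.0 * A(2x3) * B(3x2) + 0.0 * C; a, b, c come from alloc/copy_htod.
backend.gemm(
    BackendTranspose::NoTrans, BackendTranspose::NoTrans,
    2, 2, 3,   // m, n, k
    1.0,       // alpha
    a, 2,      // a_ptr, lda
    b, 3,      // b_ptr, ldb
    0.0,       // beta
    c, 2,      // c_ptr, ldc
)?;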

fn conv2d_forward(
    &self,
    input_ptr: u64,
    input_shape: &[usize],
    filter_ptr: u64,
    filter_shape: &[usize],
    output_ptr: u64,
    output_shape: &[usize],
    stride: &[usize],
    padding: &[usize],
) -> BackendResult<()>

2D convolution forward pass.

§Arguments
  • input_ptr — device pointer to input tensor (NCHW layout).
  • input_shape — [N, C, H, W].
  • filter_ptr — device pointer to filter tensor.
  • filter_shape — [K, C, Fh, Fw].
  • output_ptr — device pointer to output tensor.
  • output_shape — [N, K, Oh, Ow].
  • stride — [sh, sw].
  • padding — [ph, pw].
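The output spatial size should follow the usual convolution relation Oh = (H + 2*ph - Fh) / sh + 1 (and likewise for Ow). A small standalone helper that derives output_shape from the other arguments; the function name is illustrative, not part of the trait:

// Derive [N, K, Oh, Ow] from input/filter shapes, stride and padding
// under the standard convolution shape rule (integer division floors).
fn conv2d_output_shape(
    input: [usize; 4],   // [N, C, H, W]
    filter: [usize; 4],  // [K, C, Fh, Fw]
    stride: [usize; 2],  // [sh, sw]
    padding: [usize; 2], // [ph, pw]
) -> [usize; 4] {
    let oh = (input[2] + 2 * padding[0] - filter[2]) / stride[0] + 1;
    let ow = (input[3] + 2 * padding[1] - filter[3]) / stride[1] + 1;
    [input[0], filter[0], oh, ow]
}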

fn attention(
    &self,
    q_ptr: u64,
    k_ptr: u64,
    v_ptr: u64,
    o_ptr: u64,
    batch: usize,
    heads: usize,
    seq_q: usize,
    seq_kv: usize,
    head_dim: usize,
    scale: f64,
    causal: bool,
) -> BackendResult<()>

Scaled dot-product attention.

Computes softmax(Q * K^T * scale) * V with optional causal masking.

§Arguments
  • q_ptr, k_ptr, v_ptr — device pointers to query, key, value tensors.
  • o_ptr — device pointer to output tensor.
  • batch, heads — batch size and number of attention heads.
  • seq_q, seq_kv — query and key/value sequence lengths.
  • head_dim — dimension of each attention head.
  • scale — attention scale factor (typically 1 / sqrt(head_dim)).
  • causal — if true, apply causal (lower-triangular) mask.
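A hedged call sketch for single-batch, 8-head causal attention over 128 tokens with head_dim = 64; the device pointers q, k, v, o are assumed to come from alloc/copy_htod:

let head_dim = 64;
let scale = 1.0 / (head_dim as f64).sqrt(); // the typical choice noted above
backend.attention(
    q, k, v, o,   // device pointers
    1, 8,         // batch, heads
    128, 128,     // seq_q, seq_kv
    head_dim,
    scale,
    true,         // apply causal mask
)?;
backend.synchronize()?;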

fn reduce(&self, op: ReduceOp, input_ptr: u64, output_ptr: u64, shape: &[usize], axis: usize) -> BackendResult<()>

Reduction along an axis.

Reduces input along axis using the specified op and writes to output.
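For example, summing a tensor of shape [2, 3] along axis 1 produces the 2 row sums. A hedged sketch, assuming ReduceOp has a Sum variant and that input and output are device pointers of appropriate sizes:

// input holds 6 elements viewed as shape [2, 3]; output receives 2 row sums.
backend.reduce(ReduceOp::Sum, input, output, &[2, 3], 1)?;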

fn unary(&self, op: UnaryOp, input_ptr: u64, output_ptr: u64, n: usize) -> BackendResult<()>

Element-wise unary operation.

Applies op to each of the n elements at input_ptr and writes to output_ptr.

fn binary(&self, op: BinaryOp, a_ptr: u64, b_ptr: u64, output_ptr: u64, n: usize) -> BackendResult<()>

Element-wise binary operation.

Applies op element-wise: output[i] = op(a[i], b[i]) for n elements.
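A hedged sketch chaining the two element-wise entry points; the UnaryOp::Exp and BinaryOp::Add variant names and the device pointers x, bias, tmp, y are assumptions:

let n = 1024;
// tmp[i] = exp(x[i])
backend.unary(UnaryOp::Exp, x, tmp, n)?;
// y[i] = tmp[i] + bias[i]
backend.binary(BinaryOp::Add, tmp, bias, y, n)?;
backend.synchronize()?;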

fn synchronize(&self) -> BackendResult<()>

Synchronize all pending operations on this backend.

Blocks the host until all previously submitted GPU work completes.

fn alloc(&self, bytes: usize) -> BackendResult<u64>

Allocate device memory.

Returns an opaque device pointer. The caller is responsible for eventually calling free.

fn free(&self, ptr: u64) -> BackendResult<()>

Free device memory previously allocated with alloc.

fn copy_htod(&self, dst: u64, src: &[u8]) -> BackendResult<()>

Copy data from host memory to device memory.

  • dst — device pointer (destination).
  • src — host byte slice (source).

fn copy_dtoh(&self, dst: &mut [u8], src: u64) -> BackendResult<()>

Copy data from device memory to host memory.

  • dst — host byte slice (destination).
  • src — device pointer (source).
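Because both copies operate on raw bytes, typed host data has to be converted explicitly. A minimal f64 round-trip sketch using only std and the methods above:

let host: Vec<f64> = vec![1.0, 2.0, 3.0];
let bytes: Vec<u8> = host.iter().flat_map(|x| x.to_ne_bytes()).collect();

let dev = backend.alloc(bytes.len())?;
backend.copy_htod(dev, &bytes)?;
// ... run kernels that read/write `dev` ...
let mut raw = vec![0u8; bytes.len()];
backend.copy_dtoh(&mut raw, dev)?;

let result: Vec<f64> = raw
    .chunks_exact(std::mem::size_of::<f64>())
    .map(|c| f64::from_ne_bytes(c.try_into().unwrap()))
    .collect();
backend.free(dev)?;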

Provided Methods§

fn batched_gemm(
    &self,
    trans_a: BackendTranspose,
    trans_b: BackendTranspose,
    m: usize,
    n: usize,
    k: usize,
    alpha: f64,
    a_ptr: u64,
    lda: usize,
    stride_a: usize,
    b_ptr: u64,
    ldb: usize,
    stride_b: usize,
    beta: f64,
    c_ptr: u64,
    ldc: usize,
    stride_c: usize,
    batch_count: usize,
) -> BackendResult<()>

Strided batched GEMM: for each batch b in 0..batch_count, compute C_b = alpha * op(A_b) * op(B_b) + beta * C_b, where A_b starts stride_a elements after A_{b-1} (and likewise B_b and C_b with stride_b and stride_c).

§Arguments
  • trans_a, trans_b — transpose modes for A and B.
  • m, n, k — matrix dimensions (C is m×n).
  • alpha, beta — scaling factors.
  • a_ptr, b_ptr, c_ptr — device pointers to the first matrix in each batch.
  • lda, ldb, ldc — leading dimensions.
  • stride_a, stride_b, stride_c — element strides between consecutive matrices.
  • batch_count — number of GEMM operations in the batch.

The default implementation dispatches batch_count individual gemm calls with pointer offsets.
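For example, 16 independent 64x64 multiplies laid out back to back (each matrix 64 * 64 = 4096 elements apart) can be issued in one call. A hedged sketch; the device pointers a, b, c and the BackendTranspose::NoTrans variant name are assumptions:

let (m, n, k) = (64, 64, 64);
let stride = m * k; // elements between consecutive matrices (A, B and C are all 64x64 here)
backend.batched_gemm(
    BackendTranspose::NoTrans, BackendTranspose::NoTrans,
    m, n, k,
    1.0,
    a, m, stride,   // a_ptr, lda, stride_a
    b, k, stride,   // b_ptr, ldb, stride_b
    0.0,
    c, m, stride,   // c_ptr, ldc, stride_c
    16,             // batch_count
)?;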

Implementors§