pub trait ComputeBackend: Send + Sync + Debug {
    // Required methods
    fn name(&self) -> &str;
    fn init(&mut self) -> BackendResult<()>;
    fn is_initialized(&self) -> bool;
    fn gemm(
        &self,
        trans_a: BackendTranspose,
        trans_b: BackendTranspose,
        m: usize,
        n: usize,
        k: usize,
        alpha: f64,
        a_ptr: u64,
        lda: usize,
        b_ptr: u64,
        ldb: usize,
        beta: f64,
        c_ptr: u64,
        ldc: usize,
    ) -> BackendResult<()>;
    fn conv2d_forward(
        &self,
        input_ptr: u64,
        input_shape: &[usize],
        filter_ptr: u64,
        filter_shape: &[usize],
        output_ptr: u64,
        output_shape: &[usize],
        stride: &[usize],
        padding: &[usize],
    ) -> BackendResult<()>;
    fn attention(
        &self,
        q_ptr: u64,
        k_ptr: u64,
        v_ptr: u64,
        o_ptr: u64,
        batch: usize,
        heads: usize,
        seq_q: usize,
        seq_kv: usize,
        head_dim: usize,
        scale: f64,
        causal: bool,
    ) -> BackendResult<()>;
    fn reduce(
        &self,
        op: ReduceOp,
        input_ptr: u64,
        output_ptr: u64,
        shape: &[usize],
        axis: usize,
    ) -> BackendResult<()>;
    fn unary(
        &self,
        op: UnaryOp,
        input_ptr: u64,
        output_ptr: u64,
        n: usize,
    ) -> BackendResult<()>;
    fn binary(
        &self,
        op: BinaryOp,
        a_ptr: u64,
        b_ptr: u64,
        output_ptr: u64,
        n: usize,
    ) -> BackendResult<()>;
    fn synchronize(&self) -> BackendResult<()>;
    fn alloc(&self, bytes: usize) -> BackendResult<u64>;
    fn free(&self, ptr: u64) -> BackendResult<()>;
    fn copy_htod(&self, dst: u64, src: &[u8]) -> BackendResult<()>;
    fn copy_dtoh(&self, dst: &mut [u8], src: u64) -> BackendResult<()>;

    // Provided method
    fn batched_gemm(
        &self,
        trans_a: BackendTranspose,
        trans_b: BackendTranspose,
        m: usize,
        n: usize,
        k: usize,
        alpha: f64,
        a_ptr: u64,
        lda: usize,
        stride_a: usize,
        b_ptr: u64,
        ldb: usize,
        stride_b: usize,
        beta: f64,
        c_ptr: u64,
        ldc: usize,
        stride_c: usize,
        batch_count: usize,
    ) -> BackendResult<()> { ... }
}
Abstract compute backend trait.
Implementations provide GPU-accelerated compute operations.
All operations work with opaque device memory pointers (u64)
and explicit shape/stride information, making the trait
independent of any particular memory management scheme.
§Object Safety
This trait is object-safe and can be used as Box<dyn ComputeBackend>
or &dyn ComputeBackend for dynamic dispatch.
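A minimal sketch of how the trait-object forms look in practice; the `Engine` wrapper and `describe` helper are illustrative only, not part of this crate:

```rust
// Because the trait is object-safe, code can accept any backend behind a
// trait object and never name a concrete implementation.
fn describe(backend: &dyn ComputeBackend) -> String {
    format!("backend {} (initialized: {})", backend.name(), backend.is_initialized())
}

// Owned trait objects work the same way, e.g. for storing a backend in a struct.
struct Engine {
    backend: Box<dyn ComputeBackend>,
}

impl Engine {
    fn flush(&self) -> BackendResult<()> {
        self.backend.synchronize()
    }
}
```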
§Lifecycle
A backend starts uninitialized: call init before issuing any other operation, and use is_initialized to check readiness. Re-initializing an already-initialized backend is a no-op.
§Required Methods
fn init(&mut self) -> BackendResult<()>
Initialize the backend (select device, create context).
Must be called before any other operation. Calling init on an
already-initialized backend is a no-op.
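A small sketch of the documented lifecycle (initialize once, then verify readiness); error handling is minimal:

```rust
// Initialize the backend if it is not ready yet. A second `init` call would
// be a no-op per the documentation above.
fn ensure_ready(backend: &mut dyn ComputeBackend) -> BackendResult<()> {
    if !backend.is_initialized() {
        backend.init()?; // selects the device and creates the context
    }
    debug_assert!(backend.is_initialized());
    Ok(())
}
```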
fn is_initialized(&self) -> bool
Returns true if the backend is ready for operations.
fn gemm(
    &self,
    trans_a: BackendTranspose,
    trans_b: BackendTranspose,
    m: usize,
    n: usize,
    k: usize,
    alpha: f64,
    a_ptr: u64,
    lda: usize,
    b_ptr: u64,
    ldb: usize,
    beta: f64,
    c_ptr: u64,
    ldc: usize,
) -> BackendResult<()>
General matrix multiply: C = alpha * op(A) * op(B) + beta * C.
§Arguments
- trans_a, trans_b — transpose modes for A and B.
- m, n, k — matrix dimensions (C is m×n, A is m×k, B is k×n after transpose).
- alpha, beta — scaling factors.
- a_ptr, b_ptr, c_ptr — device pointers to column-major f64 matrices.
- lda, ldb, ldc — leading dimensions.
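As a worked example, a 2×2 product driven end to end through alloc, copy_htod, gemm, synchronize, copy_dtoh, and free. The data layout (column-major f64) follows the argument descriptions above; the BackendTranspose::None variant name is an assumption:

```rust
fn gemm_2x2(backend: &dyn ComputeBackend) -> BackendResult<Vec<f64>> {
    // Column-major 2x2 matrices: A = [[1, 2], [3, 4]], B = identity.
    let a = [1.0f64, 3.0, 2.0, 4.0];
    let b = [1.0f64, 0.0, 0.0, 1.0];
    let bytes = |m: &[f64]| m.iter().flat_map(|x| x.to_ne_bytes()).collect::<Vec<u8>>();

    let a_ptr = backend.alloc(32)?; // 4 f64 elements = 32 bytes
    let b_ptr = backend.alloc(32)?;
    let c_ptr = backend.alloc(32)?;
    backend.copy_htod(a_ptr, &bytes(&a))?;
    backend.copy_htod(b_ptr, &bytes(&b))?;

    // C = 1.0 * A * B + 0.0 * C, with m = n = k = 2 and leading dimension 2.
    // `BackendTranspose::None` is an assumed variant name; check the actual enum.
    backend.gemm(
        BackendTranspose::None, BackendTranspose::None,
        2, 2, 2, 1.0, a_ptr, 2, b_ptr, 2, 0.0, c_ptr, 2,
    )?;
    backend.synchronize()?;

    let mut out = vec![0u8; 32];
    backend.copy_dtoh(&mut out, c_ptr)?;
    let c: Vec<f64> = out
        .chunks_exact(8)
        .map(|ch| f64::from_ne_bytes(ch.try_into().unwrap()))
        .collect();

    backend.free(a_ptr)?;
    backend.free(b_ptr)?;
    backend.free(c_ptr)?;
    Ok(c) // column-major C; equals A here because B is the identity
}
```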
fn conv2d_forward(
    &self,
    input_ptr: u64,
    input_shape: &[usize],
    filter_ptr: u64,
    filter_shape: &[usize],
    output_ptr: u64,
    output_shape: &[usize],
    stride: &[usize],
    padding: &[usize],
) -> BackendResult<()>
2D convolution forward pass.
§Arguments
- input_ptr — device pointer to input tensor (NCHW layout).
- input_shape — [N, C, H, W].
- filter_ptr — device pointer to filter tensor.
- filter_shape — [K, C, Fh, Fw].
- output_ptr — device pointer to output tensor.
- output_shape — [N, K, Oh, Ow].
- stride — [sh, sw].
- padding — [ph, pw].
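The output spatial dimensions must be consistent with the input, filter, stride, and padding. A small helper that derives them under the usual convention Oh = (H + 2*ph - Fh) / sh + 1; this convention is an assumption, since the crate's exact rounding and validation rules are not documented here:

```rust
// Assumed convention: Oh = (H + 2*ph - Fh) / sh + 1 (and likewise for width).
fn conv2d_output_shape(
    input_shape: &[usize],  // [N, C, H, W]
    filter_shape: &[usize], // [K, C, Fh, Fw]
    stride: &[usize],       // [sh, sw]
    padding: &[usize],      // [ph, pw]
) -> [usize; 4] {
    let (n, k) = (input_shape[0], filter_shape[0]);
    let oh = (input_shape[2] + 2 * padding[0] - filter_shape[2]) / stride[0] + 1;
    let ow = (input_shape[3] + 2 * padding[1] - filter_shape[3]) / stride[1] + 1;
    [n, k, oh, ow]
}
```

The resulting [N, K, Oh, Ow] value is what would be passed as output_shape, with output_ptr pointing to a buffer large enough for N * K * Oh * Ow elements.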
fn attention(
    &self,
    q_ptr: u64,
    k_ptr: u64,
    v_ptr: u64,
    o_ptr: u64,
    batch: usize,
    heads: usize,
    seq_q: usize,
    seq_kv: usize,
    head_dim: usize,
    scale: f64,
    causal: bool,
) -> BackendResult<()>
Scaled dot-product attention.
Computes softmax(Q * K^T * scale) * V with optional causal masking.
§Arguments
- q_ptr, k_ptr, v_ptr — device pointers to query, key, value tensors.
- o_ptr — device pointer to output tensor.
- batch, heads — batch size and number of attention heads.
- seq_q, seq_kv — query and key/value sequence lengths.
- head_dim — dimension of each attention head.
- scale — attention scale factor (typically 1 / sqrt(head_dim)).
- causal — if true, apply causal (lower-triangular) mask.
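A sketch of a causal self-attention call. It assumes Q, K, V, and O were already allocated and filled on the device; the element dtype and exact per-tensor layout are not specified by this trait, so they are left to the backend:

```rust
// Sketch only: Q/K/V/O are assumed to hold batch * heads * seq * head_dim
// elements each, already resident on the device.
fn causal_self_attention(
    backend: &dyn ComputeBackend,
    q_ptr: u64, k_ptr: u64, v_ptr: u64, o_ptr: u64,
    batch: usize, heads: usize, seq: usize, head_dim: usize,
) -> BackendResult<()> {
    let scale = 1.0 / (head_dim as f64).sqrt(); // the typical scale factor
    backend.attention(
        q_ptr, k_ptr, v_ptr, o_ptr,
        batch, heads,
        seq, seq,   // self-attention: seq_q == seq_kv
        head_dim,
        scale,
        true,       // apply the causal (lower-triangular) mask
    )?;
    backend.synchronize()
}
```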
fn reduce(
    &self,
    op: ReduceOp,
    input_ptr: u64,
    output_ptr: u64,
    shape: &[usize],
    axis: usize,
) -> BackendResult<()>
Reduction along an axis.
Reduces input along axis using the specified op and writes to output.
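For example, reducing a [2, 3] tensor along axis 1 produces 2 output elements (one per row). A sketch of sizing the output buffer before calling reduce; the ReduceOp::Sum variant name and the 8-byte (f64) element size are assumptions:

```rust
// Output element count = product of all dimensions except the reduced axis.
fn reduce_sum_axis(
    backend: &dyn ComputeBackend,
    input_ptr: u64,
    shape: &[usize],
    axis: usize,
) -> BackendResult<u64> {
    let out_elems: usize = shape
        .iter()
        .enumerate()
        .filter(|(i, _)| *i != axis)
        .map(|(_, d)| d)
        .product();
    // Assumes f64 (8-byte) elements; the trait does not fix the dtype here.
    let output_ptr = backend.alloc(out_elems * 8)?;
    // `ReduceOp::Sum` is an assumed variant name; check the actual enum.
    backend.reduce(ReduceOp::Sum, input_ptr, output_ptr, shape, axis)?;
    Ok(output_ptr)
}
```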
fn unary(
    &self,
    op: UnaryOp,
    input_ptr: u64,
    output_ptr: u64,
    n: usize,
) -> BackendResult<()>
Element-wise unary operation.
Applies op to each of the n elements at input_ptr and writes to output_ptr.
fn binary(
    &self,
    op: BinaryOp,
    a_ptr: u64,
    b_ptr: u64,
    output_ptr: u64,
    n: usize,
) -> BackendResult<()>
Element-wise binary operation.
Applies op element-wise: output[i] = op(a[i], b[i]) for n elements.
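A small sketch of an element-wise add over two n-element device buffers; BinaryOp::Add is an assumed variant name:

```rust
// Sketch: a and b each hold n elements on the device; out has room for n more.
fn elementwise_add(
    backend: &dyn ComputeBackend,
    a_ptr: u64,
    b_ptr: u64,
    out_ptr: u64,
    n: usize,
) -> BackendResult<()> {
    // `BinaryOp::Add` is an assumed variant name; check the actual enum.
    backend.binary(BinaryOp::Add, a_ptr, b_ptr, out_ptr, n)?;
    backend.synchronize()
}
```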
fn synchronize(&self) -> BackendResult<()>
Synchronize all pending operations on this backend.
Blocks the host until all previously submitted GPU work completes.
fn alloc(&self, bytes: usize) -> BackendResult<u64>
Allocate device memory.
Returns an opaque device pointer. The caller is responsible for
eventually calling free.
fn free(&self, ptr: u64) -> BackendResult<()>
Free device memory previously allocated with alloc.
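Because alloc returns a raw u64 that the caller must eventually free, one common pattern is a small RAII guard that pairs the two. The following is purely an illustration and is not provided by the crate:

```rust
// Illustrative RAII guard (not part of this crate): frees the allocation on drop.
struct DeviceBuffer<'a> {
    backend: &'a dyn ComputeBackend,
    ptr: u64,
}

impl<'a> DeviceBuffer<'a> {
    fn new(backend: &'a dyn ComputeBackend, bytes: usize) -> BackendResult<Self> {
        Ok(Self { backend, ptr: backend.alloc(bytes)? })
    }

    fn ptr(&self) -> u64 {
        self.ptr
    }
}

impl Drop for DeviceBuffer<'_> {
    fn drop(&mut self) {
        // Errors from `free` cannot propagate out of `drop`; ignore or log them.
        let _ = self.backend.free(self.ptr);
    }
}
```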
§Provided Methods
fn batched_gemm(
    &self,
    trans_a: BackendTranspose,
    trans_b: BackendTranspose,
    m: usize,
    n: usize,
    k: usize,
    alpha: f64,
    a_ptr: u64,
    lda: usize,
    stride_a: usize,
    b_ptr: u64,
    ldb: usize,
    stride_b: usize,
    beta: f64,
    c_ptr: u64,
    ldc: usize,
    stride_c: usize,
    batch_count: usize,
) -> BackendResult<()>
Strided batched GEMM: for each batch b in 0..batch_count,
compute C_b = alpha * op(A_b) * op(B_b) + beta * C_b
where A_b starts at a_ptr + b * stride_a * 4 bytes (f32 elements), etc.
§Arguments
- trans_a, trans_b — transpose modes for A and B.
- m, n, k — matrix dimensions (C is m×n).
- alpha, beta — scaling factors.
- a_ptr, b_ptr, c_ptr — device pointers to the first matrix in each batch.
- lda, ldb, ldc — leading dimensions.
- stride_a, stride_b, stride_c — element strides between consecutive matrices.
- batch_count — number of GEMM operations in the batch.
The default implementation dispatches batch_count individual
gemm calls with pointer offsets.
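A sketch of what that default dispatch looks like, using the 4-byte-per-element offsets described above. This is an approximation of the provided method, not its exact source:

```rust
// Approximation of the provided default: one plain `gemm` call per batch entry,
// with pointers advanced by the documented 4-byte (f32) element strides.
// Assumes `BackendTranspose` is `Copy`; otherwise clone it per call.
fn batched_gemm_fallback(
    backend: &dyn ComputeBackend,
    trans_a: BackendTranspose, trans_b: BackendTranspose,
    m: usize, n: usize, k: usize,
    alpha: f64, a_ptr: u64, lda: usize, stride_a: usize,
    b_ptr: u64, ldb: usize, stride_b: usize,
    beta: f64, c_ptr: u64, ldc: usize, stride_c: usize,
    batch_count: usize,
) -> BackendResult<()> {
    for b in 0..batch_count {
        backend.gemm(
            trans_a, trans_b, m, n, k, alpha,
            a_ptr + (b * stride_a * 4) as u64, lda,
            b_ptr + (b * stride_b * 4) as u64, ldb,
            beta,
            c_ptr + (b * stride_c * 4) as u64, ldc,
        )?;
    }
    Ok(())
}
```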