oxicuda_backend/lib.rs
//! Abstract compute backend for GPU-accelerated operations.
//!
//! The [`ComputeBackend`] trait defines the interface for GPU computation,
//! allowing higher-level crates (SciRS2, oxionnx, ToRSh, TrustformeRS)
//! to use GPU acceleration without coupling to specific GPU APIs.
//!
//! # Architecture
//!
//! ```text
//! ┌─────────────────────────────┐
//! │  SciRS2 / ToRSh / oxionnx   │
//! │         (consumers)         │
//! └─────────────┬───────────────┘
//!               │ dyn ComputeBackend
//! ┌─────────────▼───────────────┐
//! │       ComputeBackend        │
//! │     (trait definition)      │
//! └─────────────┬───────────────┘
//!               │
//! ┌─────────────▼───────────────┐
//! │ CudaBackend / MetalBackend  │
//! │      (concrete impls)       │
//! └─────────────────────────────┘
//! ```
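//!
//! # Example
//!
//! A minimal sketch of the dependency-inversion pattern above: a consumer
//! written against `&dyn ComputeBackend` never names a concrete backend.
//! (`negate` and its pointer arguments are hypothetical; any initialized
//! backend and valid device pointers from [`ComputeBackend::alloc`] work.)
//!
//! ```no_run
//! use oxicuda_backend::{BackendResult, ComputeBackend, UnaryOp};
//!
//! /// Negates `n` elements at device pointer `x`, writing to `y`.
//! fn negate(be: &dyn ComputeBackend, x: u64, y: u64, n: usize) -> BackendResult<()> {
//!     be.unary(UnaryOp::Neg, x, y, n)?;
//!     be.synchronize()
//! }
//! ```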

use std::fmt;

// ─── Error types ────────────────────────────────────────────

/// Error type for backend operations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BackendError {
    /// The requested operation is not supported by this backend.
    Unsupported(String),
    /// A GPU/device error occurred.
    DeviceError(String),
    /// Invalid argument to an operation.
    InvalidArgument(String),
    /// Out of device memory.
    OutOfMemory,
    /// Backend not initialized — call [`ComputeBackend::init`] first.
    NotInitialized,
}

impl fmt::Display for BackendError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Unsupported(msg) => write!(f, "unsupported operation: {msg}"),
            Self::DeviceError(msg) => write!(f, "device error: {msg}"),
            Self::InvalidArgument(msg) => write!(f, "invalid argument: {msg}"),
            Self::OutOfMemory => write!(f, "out of device memory"),
            Self::NotInitialized => write!(f, "backend not initialized"),
        }
    }
}

impl std::error::Error for BackendError {}

/// Result type for backend operations.
pub type BackendResult<T> = Result<T, BackendError>;

// ─── Operation enums ────────────────────────────────────────

/// Transpose mode for matrix operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BackendTranspose {
    /// No transpose — use the matrix as-is.
    NoTrans,
    /// Transpose (swap rows and columns).
    Trans,
    /// Conjugate transpose (Hermitian).
    ConjTrans,
}

impl fmt::Display for BackendTranspose {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::NoTrans => write!(f, "N"),
            Self::Trans => write!(f, "T"),
            Self::ConjTrans => write!(f, "C"),
        }
    }
}

/// Reduction operation applied along an axis.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReduceOp {
    /// Summation.
    Sum,
    /// Maximum value.
    Max,
    /// Minimum value.
    Min,
    /// Arithmetic mean.
    Mean,
}

impl fmt::Display for ReduceOp {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Sum => write!(f, "sum"),
            Self::Max => write!(f, "max"),
            Self::Min => write!(f, "min"),
            Self::Mean => write!(f, "mean"),
        }
    }
}

/// Element-wise unary operation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UnaryOp {
    /// Rectified linear unit: max(0, x).
    Relu,
    /// Sigmoid: 1 / (1 + exp(-x)).
    Sigmoid,
    /// Hyperbolic tangent.
    Tanh,
    /// Exponential.
    Exp,
    /// Natural logarithm.
    Log,
    /// Square root.
    Sqrt,
    /// Absolute value.
    Abs,
    /// Negation.
    Neg,
}

impl fmt::Display for UnaryOp {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Relu => write!(f, "relu"),
            Self::Sigmoid => write!(f, "sigmoid"),
            Self::Tanh => write!(f, "tanh"),
            Self::Exp => write!(f, "exp"),
            Self::Log => write!(f, "log"),
            Self::Sqrt => write!(f, "sqrt"),
            Self::Abs => write!(f, "abs"),
            Self::Neg => write!(f, "neg"),
        }
    }
}

/// Element-wise binary operation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BinaryOp {
    /// Addition.
    Add,
    /// Subtraction.
    Sub,
    /// Multiplication.
    Mul,
    /// Division.
    Div,
    /// Element-wise maximum.
    Max,
    /// Element-wise minimum.
    Min,
}

impl fmt::Display for BinaryOp {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Add => write!(f, "add"),
            Self::Sub => write!(f, "sub"),
            Self::Mul => write!(f, "mul"),
            Self::Div => write!(f, "div"),
            Self::Max => write!(f, "max"),
            Self::Min => write!(f, "min"),
        }
    }
}

// ─── ComputeBackend trait ───────────────────────────────────

/// Abstract compute backend trait.
///
/// Implementations provide GPU-accelerated compute operations.
/// All operations work with opaque device memory pointers (`u64`)
/// and explicit shape/stride information, making the trait
/// independent of any particular memory management scheme.
///
/// # Object Safety
///
/// This trait is object-safe and can be used as `Box<dyn ComputeBackend>`
/// or `&dyn ComputeBackend` for dynamic dispatch.
///
/// # Lifecycle
///
/// 1. Create the backend (`CudaBackend::new()`).
/// 2. Call [`init`](ComputeBackend::init) to select a device and create a context.
/// 3. Allocate memory with [`alloc`](ComputeBackend::alloc).
/// 4. Transfer data with [`copy_htod`](ComputeBackend::copy_htod).
/// 5. Run compute operations ([`gemm`](ComputeBackend::gemm), [`conv2d_forward`](ComputeBackend::conv2d_forward), etc.).
/// 6. Read results with [`copy_dtoh`](ComputeBackend::copy_dtoh).
/// 7. Free memory with [`free`](ComputeBackend::free).
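///
/// A minimal sketch of that lifecycle as a generic helper (the buffer sizes
/// are illustrative, and since concrete backends such as `CudaBackend` live
/// in separate crates, the backend is taken as a parameter):
///
/// ```no_run
/// use oxicuda_backend::{BackendResult, ComputeBackend};
///
/// /// Rounds a host byte buffer through device memory and back.
/// fn round_trip(be: &mut dyn ComputeBackend, host: &[u8]) -> BackendResult<Vec<u8>> {
///     be.init()?;                      // 2. select device, create context
///     let buf = be.alloc(host.len())?; // 3. device allocation
///     be.copy_htod(buf, host)?;        // 4. host -> device
///     // 5. compute operations would go here.
///     let mut out = vec![0u8; host.len()];
///     be.copy_dtoh(&mut out, buf)?;    // 6. device -> host
///     be.free(buf)?;                   // 7. release
///     Ok(out)
/// }
/// ```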
pub trait ComputeBackend: Send + Sync + fmt::Debug {
    /// Backend name (e.g., `"cuda"`, `"rocm"`, `"metal"`).
    fn name(&self) -> &str;

    /// Initialize the backend (select device, create context).
    ///
    /// Must be called before any other operation. Calling `init` on an
    /// already-initialized backend is a no-op.
    fn init(&mut self) -> BackendResult<()>;

    /// Returns `true` if the backend is ready for operations.
    fn is_initialized(&self) -> bool;

    /// General matrix multiply: `C = alpha * op(A) * op(B) + beta * C`.
    ///
    /// # Arguments
    ///
    /// * `trans_a`, `trans_b` — transpose modes for A and B.
    /// * `m`, `n`, `k` — matrix dimensions (C is m×n, op(A) is m×k, op(B) is k×n).
    /// * `alpha`, `beta` — scaling factors.
    /// * `a_ptr`, `b_ptr`, `c_ptr` — device pointers to column-major f64 matrices.
    /// * `lda`, `ldb`, `ldc` — leading dimensions.
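    ///
    /// # Example
    ///
    /// A minimal sketch (the pointers `a`, `b`, `c` are assumed to come from
    /// [`alloc`](ComputeBackend::alloc) and to hold column-major f64 data):
    ///
    /// ```no_run
    /// use oxicuda_backend::{BackendResult, BackendTranspose, ComputeBackend};
    ///
    /// /// C (4x3) = A (4x2) * B (2x3): no transposes, unit alpha, zero beta.
    /// fn multiply(be: &dyn ComputeBackend, a: u64, b: u64, c: u64) -> BackendResult<()> {
    ///     be.gemm(
    ///         BackendTranspose::NoTrans,
    ///         BackendTranspose::NoTrans,
    ///         4, 3, 2,   // m, n, k
    ///         1.0, a, 4, // alpha, A, lda = m
    ///         b, 2,      // B, ldb = k
    ///         0.0, c, 4, // beta, C, ldc = m
    ///     )
    /// }
    /// ```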
    #[allow(clippy::too_many_arguments)]
    fn gemm(
        &self,
        trans_a: BackendTranspose,
        trans_b: BackendTranspose,
        m: usize,
        n: usize,
        k: usize,
        alpha: f64,
        a_ptr: u64,
        lda: usize,
        b_ptr: u64,
        ldb: usize,
        beta: f64,
        c_ptr: u64,
        ldc: usize,
    ) -> BackendResult<()>;

    /// 2D convolution forward pass.
    ///
    /// # Arguments
    ///
    /// * `input_ptr` — device pointer to input tensor (NCHW layout).
    /// * `input_shape` — `[N, C, H, W]`.
    /// * `filter_ptr` — device pointer to filter tensor.
    /// * `filter_shape` — `[K, C, Fh, Fw]`.
    /// * `output_ptr` — device pointer to output tensor.
    /// * `output_shape` — `[N, K, Oh, Ow]`.
    /// * `stride` — `[sh, sw]`.
    /// * `padding` — `[ph, pw]`.
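    ///
    /// The output spatial dimensions are expected to satisfy the usual
    /// (dilation-free) convolution relation, so e.g. a 32×32 input with a
    /// 3×3 filter, stride 1, and padding 1 yields a 32×32 output:
    ///
    /// ```text
    /// Oh = (H + 2 * ph - Fh) / sh + 1
    /// Ow = (W + 2 * pw - Fw) / sw + 1
    /// ```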
    #[allow(clippy::too_many_arguments)]
    fn conv2d_forward(
        &self,
        input_ptr: u64,
        input_shape: &[usize],
        filter_ptr: u64,
        filter_shape: &[usize],
        output_ptr: u64,
        output_shape: &[usize],
        stride: &[usize],
        padding: &[usize],
    ) -> BackendResult<()>;

    /// Scaled dot-product attention.
    ///
    /// Computes `softmax(scale * Q * K^T) * V` with optional causal masking.
    ///
    /// # Arguments
    ///
    /// * `q_ptr`, `k_ptr`, `v_ptr` — device pointers to query, key, value tensors.
    /// * `o_ptr` — device pointer to output tensor.
    /// * `batch`, `heads` — batch size and number of attention heads.
    /// * `seq_q`, `seq_kv` — query and key/value sequence lengths.
    /// * `head_dim` — dimension of each attention head.
    /// * `scale` — attention scale factor (typically `1 / sqrt(head_dim)`).
    /// * `causal` — if `true`, apply causal (lower-triangular) mask.
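    ///
    /// A minimal sketch of a causal self-attention call (sizes are
    /// illustrative, and the pointers are assumed to come from
    /// [`alloc`](ComputeBackend::alloc) with whatever layout the backend
    /// expects):
    ///
    /// ```no_run
    /// use oxicuda_backend::{BackendResult, ComputeBackend};
    ///
    /// fn self_attend(be: &dyn ComputeBackend, q: u64, k: u64, v: u64, o: u64) -> BackendResult<()> {
    ///     let (batch, heads, seq, head_dim) = (1, 8, 128, 64);
    ///     let scale = 1.0 / (head_dim as f64).sqrt();
    ///     // Self-attention: queries and keys/values share one sequence length.
    ///     be.attention(q, k, v, o, batch, heads, seq, seq, head_dim, scale, true)
    /// }
    /// ```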
    #[allow(clippy::too_many_arguments)]
    fn attention(
        &self,
        q_ptr: u64,
        k_ptr: u64,
        v_ptr: u64,
        o_ptr: u64,
        batch: usize,
        heads: usize,
        seq_q: usize,
        seq_kv: usize,
        head_dim: usize,
        scale: f64,
        causal: bool,
    ) -> BackendResult<()>;

    /// Reduction along an axis.
    ///
    /// Reduces the tensor at `input_ptr` along `axis` using the specified `op`
    /// and writes the result to `output_ptr`.
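    ///
    /// For example, reducing a `[2, 3]` tensor along `axis = 1` with
    /// [`ReduceOp::Sum`] is expected to produce 2 output elements (one row
    /// sum each), so `output_ptr` must point at a buffer sized for the input
    /// shape with `axis` removed.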
    fn reduce(
        &self,
        op: ReduceOp,
        input_ptr: u64,
        output_ptr: u64,
        shape: &[usize],
        axis: usize,
    ) -> BackendResult<()>;

    /// Element-wise unary operation.
    ///
    /// Applies `op` to each of the `n` elements at `input_ptr` and writes to `output_ptr`.
    fn unary(&self, op: UnaryOp, input_ptr: u64, output_ptr: u64, n: usize) -> BackendResult<()>;

    /// Element-wise binary operation.
    ///
    /// Applies `op` element-wise: `output[i] = op(a[i], b[i])` for `n` elements.
    fn binary(
        &self,
        op: BinaryOp,
        a_ptr: u64,
        b_ptr: u64,
        output_ptr: u64,
        n: usize,
    ) -> BackendResult<()>;

    /// Synchronize all pending operations on this backend.
    ///
    /// Blocks the host until all previously submitted GPU work completes.
    fn synchronize(&self) -> BackendResult<()>;

    /// Allocate device memory.
    ///
    /// Returns an opaque device pointer. The caller is responsible for
    /// eventually calling [`free`](ComputeBackend::free).
    fn alloc(&self, bytes: usize) -> BackendResult<u64>;

    /// Free device memory previously allocated with [`alloc`](ComputeBackend::alloc).
    fn free(&self, ptr: u64) -> BackendResult<()>;

    /// Copy data from host memory to device memory.
    ///
    /// * `dst` — device pointer (destination).
    /// * `src` — host byte slice (source).
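    ///
    /// Since both copy directions are byte-oriented, numeric host data must
    /// be viewed as raw bytes. A minimal sketch for uploading `f64` values
    /// (`buf` is assumed to come from [`alloc`](ComputeBackend::alloc) with
    /// at least `8 * xs.len()` bytes):
    ///
    /// ```no_run
    /// use oxicuda_backend::{BackendResult, ComputeBackend};
    ///
    /// fn upload_f64(be: &dyn ComputeBackend, buf: u64, xs: &[f64]) -> BackendResult<()> {
    ///     let bytes: Vec<u8> = xs.iter().flat_map(|x| x.to_ne_bytes()).collect();
    ///     be.copy_htod(buf, &bytes)
    /// }
    /// ```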
    fn copy_htod(&self, dst: u64, src: &[u8]) -> BackendResult<()>;

    /// Copy data from device memory to host memory.
    ///
    /// * `dst` — host byte slice (destination).
    /// * `src` — device pointer (source).
    fn copy_dtoh(&self, dst: &mut [u8], src: u64) -> BackendResult<()>;
}

// ─── Tests ──────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn backend_error_display() {
        assert_eq!(
            BackendError::Unsupported("foo".into()).to_string(),
            "unsupported operation: foo"
        );
        assert_eq!(
            BackendError::DeviceError("bar".into()).to_string(),
            "device error: bar"
        );
        assert_eq!(
            BackendError::InvalidArgument("baz".into()).to_string(),
            "invalid argument: baz"
        );
        assert_eq!(
            BackendError::OutOfMemory.to_string(),
            "out of device memory"
        );
        assert_eq!(
            BackendError::NotInitialized.to_string(),
            "backend not initialized"
        );
    }

    #[test]
    fn backend_error_is_std_error() {
        let err: Box<dyn std::error::Error> = Box::new(BackendError::DeviceError("test".into()));
        assert!(err.to_string().contains("test"));
    }
384 #[test]
385 fn backend_transpose_display_and_values() {
386 assert_eq!(BackendTranspose::NoTrans.to_string(), "N");
387 assert_eq!(BackendTranspose::Trans.to_string(), "T");
388 assert_eq!(BackendTranspose::ConjTrans.to_string(), "C");
389
390 // Equality
391 assert_eq!(BackendTranspose::NoTrans, BackendTranspose::NoTrans);
392 assert_ne!(BackendTranspose::NoTrans, BackendTranspose::Trans);
393 }
394
395 #[test]
396 fn reduce_op_display_and_coverage() {
397 let ops = [ReduceOp::Sum, ReduceOp::Max, ReduceOp::Min, ReduceOp::Mean];
398 let names = ["sum", "max", "min", "mean"];
399 for (op, name) in ops.iter().zip(names.iter()) {
400 assert_eq!(op.to_string(), *name);
401 }
402 }
403
404 #[test]
405 fn unary_op_display_and_coverage() {
406 let ops = [
407 UnaryOp::Relu,
408 UnaryOp::Sigmoid,
409 UnaryOp::Tanh,
410 UnaryOp::Exp,
411 UnaryOp::Log,
412 UnaryOp::Sqrt,
413 UnaryOp::Abs,
414 UnaryOp::Neg,
415 ];
416 let names = [
417 "relu", "sigmoid", "tanh", "exp", "log", "sqrt", "abs", "neg",
418 ];
419 for (op, name) in ops.iter().zip(names.iter()) {
420 assert_eq!(op.to_string(), *name);
421 }
422 }
423
424 #[test]
425 fn binary_op_display_and_coverage() {
426 let ops = [
427 BinaryOp::Add,
428 BinaryOp::Sub,
429 BinaryOp::Mul,
430 BinaryOp::Div,
431 BinaryOp::Max,
432 BinaryOp::Min,
433 ];
434 let names = ["add", "sub", "mul", "div", "max", "min"];
435 for (op, name) in ops.iter().zip(names.iter()) {
436 assert_eq!(op.to_string(), *name);
437 }
438 }
439
440 #[test]
441 fn enum_clone_and_hash() {
442 use std::collections::HashSet;
443
444 let mut set = HashSet::new();
445 set.insert(ReduceOp::Sum);
446 set.insert(ReduceOp::Max);
447 assert!(set.contains(&ReduceOp::Sum));
448 assert!(!set.contains(&ReduceOp::Min));
449
450 // Clone
451 let op = UnaryOp::Relu;
452 let cloned = op;
453 assert_eq!(op, cloned);
454
455 let bop = BinaryOp::Add;
456 let bcloned = bop;
457 assert_eq!(bop, bcloned);
458
459 let trans = BackendTranspose::ConjTrans;
460 let tcloned = trans;
461 assert_eq!(trans, tcloned);
462 }
463}