oxicuda_blas/types.rs
1//! Common types shared across BLAS operations.
2//!
3//! This module defines enumerations for controlling BLAS behaviour,
4//! such as math precision mode and scalar pointer location, as well as
5//! the [`GpuFloat`] trait that abstracts over GPU-compatible floating-point
6//! types, [`VectorDesc`] for describing strided vector layouts, and
7//! [`MatrixDesc`] / [`MatrixDescMut`] for describing dense matrices on the
8//! device.
9
10use std::marker::PhantomData;
11
12use oxicuda_driver::ffi::CUdeviceptr;
13use oxicuda_memory::DeviceBuffer;
14use oxicuda_ptx::ir::PtxType;
15
16use crate::error::{BlasError, BlasResult};
17
18// ---------------------------------------------------------------------------
19// MathMode — precision / throughput trade-off
20// ---------------------------------------------------------------------------
21
/// Selects whether Tensor-Core (reduced-precision) code paths may be used.
///
/// With [`TensorCore`](Self::TensorCore), GEMM and similar routines are
/// allowed to issue FP16/BF16/TF32 Tensor-Core instructions, gaining
/// throughput in exchange for slightly reduced numerical precision.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum MathMode {
    /// Standard FMA pipelines only (FP32/FP64). This is the default.
    Default,
    /// Tensor-Core instructions are permitted when the device supports them.
    TensorCore,
    /// Lowest available precision, for maximum throughput.
    MaxPerformance,
}
36
37// ---------------------------------------------------------------------------
38// PointerMode — where scalar arguments reside
39// ---------------------------------------------------------------------------
40
/// Location of the scalar arguments (alpha, beta) passed to a routine.
///
/// [`Host`](Self::Host) is the right choice for most users. Selecting
/// [`Device`](Self::Device) avoids a host-device synchronisation barrier
/// when the scalars were already computed on the GPU (e.g. in a training
/// loop).
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum PointerMode {
    /// Scalars are read from host memory (default).
    Host,
    /// Scalars live in device memory.
    Device,
}
53
54// ---------------------------------------------------------------------------
55// Layout — memory ordering of a dense matrix
56// ---------------------------------------------------------------------------
57
/// Element ordering of a dense matrix in memory.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Layout {
    /// C-style row-major storage: `A[i][j] = ptr[i * lda + j]`.
    RowMajor,
    /// Fortran-style column-major storage: `A[i][j] = ptr[j * lda + i]`.
    ColMajor,
}
66
67// ---------------------------------------------------------------------------
68// Transpose — matrix transposition mode
69// ---------------------------------------------------------------------------
70
/// How a matrix operand is transformed before it is used.
///
/// Mirrors the classic BLAS `TRANSA` / `TRANSB` parameter: the operand is
/// taken as-is, transposed, or conjugate-transposed.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Transpose {
    /// No transposition; the matrix is used as stored.
    NoTrans,
    /// The transpose of the matrix (A^T) is used.
    Trans,
    /// The conjugate-transpose (A^H) is used. For real element types this
    /// is the same as [`Trans`](Self::Trans).
    ConjTrans,
}
85
86// ---------------------------------------------------------------------------
87// FillMode — upper / lower triangle selection
88// ---------------------------------------------------------------------------
89
/// Selects the stored triangle of a symmetric or triangular matrix.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum FillMode {
    /// Only the upper triangle is stored / referenced.
    Upper,
    /// Only the lower triangle is stored / referenced.
    Lower,
    /// Both triangles, i.e. the full matrix.
    Full,
}
100
101// ---------------------------------------------------------------------------
102// Side — left/right operand position
103// ---------------------------------------------------------------------------
104
/// Position of the special (symmetric / triangular) matrix in a
/// two-operand BLAS-3 operation.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Side {
    /// Special matrix on the left: `op(A) * B`.
    Left,
    /// Special matrix on the right: `B * op(A)`.
    Right,
}
114
115// ---------------------------------------------------------------------------
116// DiagType — unit / non-unit diagonal
117// ---------------------------------------------------------------------------
118
/// Tells whether the diagonal of a triangular matrix is implicitly unit.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum DiagType {
    /// Diagonal entries are stored explicitly and used as they are.
    NonUnit,
    /// Diagonal entries are implicitly all ones (unit diagonal).
    Unit,
}
127
128// ---------------------------------------------------------------------------
129// GpuFloat — trait for GPU-compatible floating-point types
130// ---------------------------------------------------------------------------
131
132/// Trait for floating-point types that can be used in GPU BLAS kernels.
133///
134/// Provides the mapping between Rust types and PTX register types, element
135/// sizes, and bit-level representation for passing scalars as kernel parameters.
136///
137/// The trait bound is deliberately minimal so that half-precision types
138/// (`half::f16`, `half::bf16`) and FP8 types can implement it.
139pub trait GpuFloat: Copy + Send + Sync + 'static + std::fmt::Debug + PartialOrd {
140 /// The PTX register type used for this precision (e.g. `PtxType::F32`).
141 const PTX_TYPE: PtxType;
142
143 /// Size of one element in bytes.
144 const SIZE: usize;
145
146 /// A short name used in generated kernel names (e.g. `"f32"`, `"f64"`).
147 const NAME: &'static str;
148
149 /// Whether this type is eligible for Tensor-Core acceleration.
150 const TENSOR_CORE_ELIGIBLE: bool;
151
152 /// The accumulator type used when this type feeds a Tensor-Core MMA.
153 ///
154 /// For f32/f64 this is `Self`; for f16/bf16/FP8 it is typically `f32`.
155 type Accumulator: GpuFloat;
156
157 /// Converts the scalar to its raw bit representation as a `u64`.
158 ///
159 /// For `f32`, the upper 32 bits are zero. For `f64`, all 64 bits are used.
160 /// This is how scalar constants are passed to PTX kernels.
161 fn to_bits_u64(self) -> u64;
162
163 /// Reconstructs a value from its raw bit representation stored in a `u64`.
164 fn from_bits_u64(bits: u64) -> Self;
165
166 /// The zero value for this type (additive identity).
167 fn gpu_zero() -> Self;
168
169 /// The one value for this type (multiplicative identity).
170 fn gpu_one() -> Self;
171
172 /// Size of one element in bytes, as `u32`.
173 ///
174 /// Convenience helper for PTX code-generation where `u32` strides are
175 /// expected (e.g. `byte_offset_addr`).
176 #[inline]
177 fn size_u32() -> u32 {
178 Self::SIZE as u32
179 }
180}
181
182// -- f32 impl -----------------------------------------------------------------
183
184impl GpuFloat for f32 {
185 const PTX_TYPE: PtxType = PtxType::F32;
186 const SIZE: usize = 4;
187 const NAME: &'static str = "f32";
188 const TENSOR_CORE_ELIGIBLE: bool = true;
189 type Accumulator = f32;
190
191 #[inline]
192 fn to_bits_u64(self) -> u64 {
193 u64::from(self.to_bits())
194 }
195
196 #[inline]
197 fn from_bits_u64(bits: u64) -> Self {
198 f32::from_bits(bits as u32)
199 }
200
201 #[inline]
202 fn gpu_zero() -> Self {
203 0.0
204 }
205
206 #[inline]
207 fn gpu_one() -> Self {
208 1.0
209 }
210}
211
212// -- f64 impl -----------------------------------------------------------------
213
214impl GpuFloat for f64 {
215 const PTX_TYPE: PtxType = PtxType::F64;
216 const SIZE: usize = 8;
217 const NAME: &'static str = "f64";
218 const TENSOR_CORE_ELIGIBLE: bool = true;
219 type Accumulator = f64;
220
221 #[inline]
222 fn to_bits_u64(self) -> u64 {
223 self.to_bits()
224 }
225
226 #[inline]
227 fn from_bits_u64(bits: u64) -> Self {
228 f64::from_bits(bits)
229 }
230
231 #[inline]
232 fn gpu_zero() -> Self {
233 0.0
234 }
235
236 #[inline]
237 fn gpu_one() -> Self {
238 1.0
239 }
240}
241
242// -- half::f16 impl (feature-gated) ------------------------------------------
243
/// IEEE half precision: 2-byte elements, accumulated in `f32`.
#[cfg(feature = "f16")]
impl GpuFloat for half::f16 {
    type Accumulator = f32;

    const NAME: &'static str = "f16";
    const PTX_TYPE: PtxType = PtxType::F16;
    const SIZE: usize = 2;
    const TENSOR_CORE_ELIGIBLE: bool = true;

    #[inline]
    fn gpu_zero() -> Self {
        Self::ZERO
    }

    #[inline]
    fn gpu_one() -> Self {
        Self::ONE
    }

    #[inline]
    fn to_bits_u64(self) -> u64 {
        // Zero-extends the 16-bit pattern.
        self.to_bits().into()
    }

    #[inline]
    fn from_bits_u64(bits: u64) -> Self {
        // Only the low 16 bits carry the value; the cast drops the rest.
        let low = bits as u16;
        Self::from_bits(low)
    }
}
272
273// -- half::bf16 impl (feature-gated) -----------------------------------------
274
/// bfloat16: 2-byte elements, accumulated in `f32`.
#[cfg(feature = "f16")]
impl GpuFloat for half::bf16 {
    type Accumulator = f32;

    const NAME: &'static str = "bf16";
    const PTX_TYPE: PtxType = PtxType::BF16;
    const SIZE: usize = 2;
    const TENSOR_CORE_ELIGIBLE: bool = true;

    #[inline]
    fn gpu_zero() -> Self {
        Self::ZERO
    }

    #[inline]
    fn gpu_one() -> Self {
        Self::ONE
    }

    #[inline]
    fn to_bits_u64(self) -> u64 {
        // Zero-extends the 16-bit pattern.
        self.to_bits().into()
    }

    #[inline]
    fn from_bits_u64(bits: u64) -> Self {
        // Only the low 16 bits carry the value; the cast drops the rest.
        let low = bits as u16;
        Self::from_bits(low)
    }
}
303
304// ---------------------------------------------------------------------------
305// FP8 types — Hopper+ (SM90) reduced-precision formats
306// ---------------------------------------------------------------------------
307
/// FP8 E4M3 format (1 sign bit, 4-bit exponent, 3-bit mantissa).
///
/// Used primarily for inference on Hopper+ GPUs. The dynamic range is smaller
/// than E5M2 but the extra mantissa bit gives better precision for weights
/// and activations that stay within range.
///
/// The wrapped `u8` is the raw encoding. Comparisons (`==`, `<`, ...) are
/// performed on the *decoded value*, not on the raw byte, so they follow the
/// usual floating-point rules: negative values order below positive ones,
/// `+0 == -0`, and NaN is unordered and unequal to everything (itself
/// included). Use `.0` to compare raw bit patterns.
#[derive(Debug, Clone, Copy)]
#[repr(transparent)]
pub struct E4M3(pub u8);

impl PartialOrd for E4M3 {
    /// Orders two values numerically; returns `None` if either is NaN.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        /// Losslessly widens an E4M3 bit pattern to `f32`.
        ///
        /// Assumes the OCP FP8 / PTX `.e4m3` encoding: exponent bias 7, no
        /// infinities, and NaN encoded as `S.1111.111`. The bias is
        /// consistent with `0x38` being 1.0.
        fn decode(bits: u8) -> f32 {
            let exponent = u32::from((bits >> 3) & 0x0F);
            let mantissa = bits & 0x07;
            let magnitude = if exponent == 0x0F && mantissa == 0x07 {
                f32::NAN
            } else if exponent == 0 {
                // Zero / subnormal: (mantissa / 8) * 2^(1 - 7) = mantissa * 2^-9.
                f32::from(mantissa) / 512.0
            } else {
                // Normal: re-bias the exponent (7 -> 127, i.e. +120) and place
                // the 3 mantissa bits at the top of f32's 23-bit fraction.
                f32::from_bits(((exponent + 120) << 23) | (u32::from(mantissa) << 20))
            };
            if bits & 0x80 == 0 {
                magnitude
            } else {
                -magnitude
            }
        }

        decode(self.0).partial_cmp(&decode(other.0))
    }
}

impl PartialEq for E4M3 {
    /// Value equality, kept consistent with [`PartialOrd`]: `+0 == -0` and
    /// NaN never compares equal.
    fn eq(&self, other: &Self) -> bool {
        self.partial_cmp(other) == Some(std::cmp::Ordering::Equal)
    }
}
316
/// FP8 E5M2 format (1 sign bit, 5-bit exponent, 2-bit mantissa).
///
/// Used primarily for training gradients on Hopper+ GPUs. The wider exponent
/// range accommodates the larger dynamic range of gradient values.
///
/// The wrapped `u8` is the raw encoding. Comparisons (`==`, `<`, ...) are
/// performed on the *decoded value*, not on the raw byte, so they follow the
/// usual floating-point rules: negative values order below positive ones,
/// `+0 == -0`, and NaN is unordered and unequal to everything (itself
/// included). Use `.0` to compare raw bit patterns.
#[derive(Debug, Clone, Copy)]
#[repr(transparent)]
pub struct E5M2(pub u8);

impl PartialOrd for E5M2 {
    /// Orders two values numerically; returns `None` if either is NaN.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        /// Losslessly widens an E5M2 bit pattern to `f32`.
        ///
        /// Assumes the OCP FP8 / PTX `.e5m2` encoding: exponent bias 15 with
        /// IEEE-style infinities and NaNs in the all-ones exponent. The bias
        /// is consistent with `0x3C` being 1.0.
        fn decode(bits: u8) -> f32 {
            let exponent = u32::from((bits >> 2) & 0x1F);
            let mantissa = bits & 0x03;
            let magnitude = match (exponent, mantissa) {
                (0x1F, 0) => f32::INFINITY,
                (0x1F, _) => f32::NAN,
                // Zero / subnormal: (mantissa / 4) * 2^(1 - 15) = mantissa * 2^-16.
                (0, m) => f32::from(m) / 65536.0,
                // Normal: re-bias the exponent (15 -> 127, i.e. +112) and place
                // the 2 mantissa bits at the top of f32's 23-bit fraction.
                (e, m) => f32::from_bits(((e + 112) << 23) | (u32::from(m) << 21)),
            };
            if bits & 0x80 == 0 {
                magnitude
            } else {
                -magnitude
            }
        }

        decode(self.0).partial_cmp(&decode(other.0))
    }
}

impl PartialEq for E5M2 {
    /// Value equality, kept consistent with [`PartialOrd`]: `+0 == -0` and
    /// NaN never compares equal.
    fn eq(&self, other: &Self) -> bool {
        self.partial_cmp(other) == Some(std::cmp::Ordering::Equal)
    }
}
324
325// SAFETY: E4M3 and E5M2 are `#[repr(transparent)]` wrappers around `u8`,
326// which is trivially Send + Sync.
327unsafe impl Send for E4M3 {}
328unsafe impl Sync for E4M3 {}
329unsafe impl Send for E5M2 {}
330unsafe impl Sync for E5M2 {}
331
332impl GpuFloat for E4M3 {
333 const PTX_TYPE: PtxType = PtxType::E4M3;
334 const SIZE: usize = 1;
335 const NAME: &'static str = "e4m3";
336 const TENSOR_CORE_ELIGIBLE: bool = true;
337 type Accumulator = f32;
338
339 #[inline]
340 fn to_bits_u64(self) -> u64 {
341 u64::from(self.0)
342 }
343
344 #[inline]
345 fn from_bits_u64(bits: u64) -> Self {
346 Self(bits as u8)
347 }
348
349 #[inline]
350 fn gpu_zero() -> Self {
351 Self(0x00)
352 }
353
354 #[inline]
355 fn gpu_one() -> Self {
356 Self(0x38)
357 }
358}
359
360impl GpuFloat for E5M2 {
361 const PTX_TYPE: PtxType = PtxType::E5M2;
362 const SIZE: usize = 1;
363 const NAME: &'static str = "e5m2";
364 const TENSOR_CORE_ELIGIBLE: bool = true;
365 type Accumulator = f32;
366
367 #[inline]
368 fn to_bits_u64(self) -> u64 {
369 u64::from(self.0)
370 }
371
372 #[inline]
373 fn from_bits_u64(bits: u64) -> Self {
374 Self(bits as u8)
375 }
376
377 #[inline]
378 fn gpu_zero() -> Self {
379 Self(0x00)
380 }
381
382 #[inline]
383 fn gpu_one() -> Self {
384 Self(0x3C)
385 }
386}
387
388// ---------------------------------------------------------------------------
389// VectorDesc — describes a strided vector on the device
390// ---------------------------------------------------------------------------
391
/// Layout of a (possibly strided) vector held in device memory.
///
/// BLAS Level 1 routines operate on vectors whose consecutive logical
/// elements may be separated by a stride (increment). The descriptor stores
/// the logical length and the stride; the buffer capacity they imply is
/// computed by [`VectorDesc::required_elements`].
#[derive(Clone, Copy, Debug)]
pub struct VectorDesc {
    /// Logical element count.
    pub n: u32,
    /// Stride (increment) between consecutive elements. Must be positive.
    pub inc: u32,
}

impl VectorDesc {
    /// Builds a descriptor from a length and a stride.
    ///
    /// # Arguments
    ///
    /// * `n` — number of logical elements.
    /// * `inc` — stride between elements (absolute value of the user-supplied
    ///   increment). Must be at least 1; the value is stored without being
    ///   checked.
    #[must_use]
    pub fn new(n: u32, inc: u32) -> Self {
        VectorDesc { n, inc }
    }

    /// Smallest element count the backing buffer must provide.
    ///
    /// The last of `n` elements with stride `inc` sits at index
    /// `(n - 1) * inc`, so `1 + (n - 1) * inc` elements are needed. An empty
    /// vector needs none.
    #[must_use]
    pub fn required_elements(&self) -> usize {
        // `checked_sub` yields `None` exactly when `n == 0`.
        match self.n.checked_sub(1) {
            None => 0,
            Some(last_index) => {
                let span = last_index as usize * self.inc as usize;
                1 + span
            }
        }
    }
}
431
432// ---------------------------------------------------------------------------
433// MatrixDesc — describes a dense matrix on the device (immutable view)
434// ---------------------------------------------------------------------------
435
436/// Describes a matrix stored in device memory.
437///
438/// This is an immutable (read-only) view. For an output matrix that will be
439/// written to, use [`MatrixDescMut`].
440///
441/// All fields are `Copy`-sized, so `MatrixDesc` itself is `Copy`.
442#[derive(Debug, Clone, Copy)]
443pub struct MatrixDesc<T: GpuFloat> {
444 /// Device pointer to the matrix data.
445 pub ptr: CUdeviceptr,
446 /// Number of rows.
447 pub rows: u32,
448 /// Number of columns.
449 pub cols: u32,
450 /// Leading dimension (stride between rows/columns depending on layout).
451 pub ld: u32,
452 /// Memory layout.
453 pub layout: Layout,
454 _phantom: PhantomData<T>,
455}
456
457impl<T: GpuFloat> MatrixDesc<T> {
458 /// Create a matrix descriptor from a [`DeviceBuffer`].
459 ///
460 /// Returns an error if the buffer is too small for the requested dimensions.
461 pub fn from_buffer(
462 buf: &DeviceBuffer<T>,
463 rows: u32,
464 cols: u32,
465 layout: Layout,
466 ) -> BlasResult<Self> {
467 let required = rows as usize * cols as usize;
468 if buf.len() < required {
469 return Err(BlasError::BufferTooSmall {
470 expected: required,
471 actual: buf.len(),
472 });
473 }
474 let ld = match layout {
475 Layout::RowMajor => cols,
476 Layout::ColMajor => rows,
477 };
478 Ok(Self {
479 ptr: buf.as_device_ptr(),
480 rows,
481 cols,
482 ld,
483 layout,
484 _phantom: PhantomData,
485 })
486 }
487
488 /// Create with a raw device pointer (no size validation).
489 pub fn from_raw(ptr: CUdeviceptr, rows: u32, cols: u32, ld: u32, layout: Layout) -> Self {
490 Self {
491 ptr,
492 rows,
493 cols,
494 ld,
495 layout,
496 _phantom: PhantomData,
497 }
498 }
499
500 /// Override the leading dimension.
501 #[must_use]
502 pub fn with_ld(mut self, ld: u32) -> Self {
503 self.ld = ld;
504 self
505 }
506
507 /// Total number of elements.
508 #[must_use]
509 pub fn numel(&self) -> usize {
510 self.rows as usize * self.cols as usize
511 }
512
513 /// Storage bytes (full stride, including padding from leading dimension).
514 #[must_use]
515 pub fn storage_bytes(&self) -> usize {
516 let major = match self.layout {
517 Layout::RowMajor => self.rows,
518 Layout::ColMajor => self.cols,
519 };
520 major as usize * self.ld as usize * T::SIZE
521 }
522
523 /// Effective dimensions after transpose.
524 #[must_use]
525 pub fn effective_dims(&self, trans: Transpose) -> (u32, u32) {
526 match trans {
527 Transpose::NoTrans => (self.rows, self.cols),
528 Transpose::Trans | Transpose::ConjTrans => (self.cols, self.rows),
529 }
530 }
531}
532
533// ---------------------------------------------------------------------------
534// MatrixDescMut — mutable matrix descriptor
535// ---------------------------------------------------------------------------
536
537/// Describes a mutable (output) matrix stored in device memory.
538///
539/// Identical to [`MatrixDesc`] but signals intent to write. This distinction
540/// prevents accidentally passing an input buffer where an output is expected.
541#[derive(Debug, Clone, Copy)]
542pub struct MatrixDescMut<T: GpuFloat> {
543 /// Device pointer to the matrix data.
544 pub ptr: CUdeviceptr,
545 /// Number of rows.
546 pub rows: u32,
547 /// Number of columns.
548 pub cols: u32,
549 /// Leading dimension (stride between rows/columns depending on layout).
550 pub ld: u32,
551 /// Memory layout.
552 pub layout: Layout,
553 _phantom: PhantomData<T>,
554}
555
556impl<T: GpuFloat> MatrixDescMut<T> {
557 /// Create a mutable matrix descriptor from a [`DeviceBuffer`].
558 ///
559 /// Returns an error if the buffer is too small for the requested dimensions.
560 pub fn from_buffer(
561 buf: &mut DeviceBuffer<T>,
562 rows: u32,
563 cols: u32,
564 layout: Layout,
565 ) -> BlasResult<Self> {
566 let required = rows as usize * cols as usize;
567 if buf.len() < required {
568 return Err(BlasError::BufferTooSmall {
569 expected: required,
570 actual: buf.len(),
571 });
572 }
573 let ld = match layout {
574 Layout::RowMajor => cols,
575 Layout::ColMajor => rows,
576 };
577 Ok(Self {
578 ptr: buf.as_device_ptr(),
579 rows,
580 cols,
581 ld,
582 layout,
583 _phantom: PhantomData,
584 })
585 }
586
587 /// Create with a raw device pointer (no size validation).
588 pub fn from_raw(ptr: CUdeviceptr, rows: u32, cols: u32, ld: u32, layout: Layout) -> Self {
589 Self {
590 ptr,
591 rows,
592 cols,
593 ld,
594 layout,
595 _phantom: PhantomData,
596 }
597 }
598
599 /// Override the leading dimension.
600 #[must_use]
601 pub fn with_ld(mut self, ld: u32) -> Self {
602 self.ld = ld;
603 self
604 }
605
606 /// Total number of elements.
607 #[must_use]
608 pub fn numel(&self) -> usize {
609 self.rows as usize * self.cols as usize
610 }
611
612 /// Storage bytes (full stride, including padding from leading dimension).
613 #[must_use]
614 pub fn storage_bytes(&self) -> usize {
615 let major = match self.layout {
616 Layout::RowMajor => self.rows,
617 Layout::ColMajor => self.cols,
618 };
619 major as usize * self.ld as usize * T::SIZE
620 }
621
622 /// Effective dimensions after transpose.
623 #[must_use]
624 pub fn effective_dims(&self, trans: Transpose) -> (u32, u32) {
625 match trans {
626 Transpose::NoTrans => (self.rows, self.cols),
627 Transpose::Trans | Transpose::ConjTrans => (self.cols, self.rows),
628 }
629 }
630
631 /// Borrow as an immutable [`MatrixDesc`].
632 #[must_use]
633 pub fn as_immutable(&self) -> MatrixDesc<T> {
634 MatrixDesc {
635 ptr: self.ptr,
636 rows: self.rows,
637 cols: self.cols,
638 ld: self.ld,
639 layout: self.layout,
640 _phantom: PhantomData,
641 }
642 }
643}