Skip to main content

oxicuda_blas/
types.rs

1//! Common types shared across BLAS operations.
2//!
3//! This module defines enumerations for controlling BLAS behaviour,
4//! such as math precision mode and scalar pointer location, as well as
5//! the [`GpuFloat`] trait that abstracts over GPU-compatible floating-point
6//! types, [`VectorDesc`] for describing strided vector layouts, and
7//! [`MatrixDesc`] / [`MatrixDescMut`] for describing dense matrices on the
8//! device.
9
10use std::marker::PhantomData;
11
12use oxicuda_driver::ffi::CUdeviceptr;
13use oxicuda_memory::DeviceBuffer;
14use oxicuda_ptx::ir::PtxType;
15
16use crate::error::{BlasError, BlasResult};
17
18// ---------------------------------------------------------------------------
19// MathMode — precision / throughput trade-off
20// ---------------------------------------------------------------------------
21
/// Selects the precision / throughput trade-off for compute-heavy routines.
///
/// With [`TensorCore`](Self::TensorCore), GEMM-like routines are permitted to
/// issue FP16/BF16/TF32 Tensor-Core instructions, gaining throughput in
/// exchange for a small loss of numerical precision.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum MathMode {
    /// Standard FMA pipelines only (FP32/FP64). This is the default mode.
    Default,
    /// Tensor-Core instructions may be used when the device supports them.
    TensorCore,
    /// Favour throughput above all: use the lowest precision available.
    MaxPerformance,
}
36
37// ---------------------------------------------------------------------------
38// PointerMode — where scalar arguments reside
39// ---------------------------------------------------------------------------
40
/// Tells a routine where its scalar arguments (alpha, beta) live.
///
/// [`Host`](Self::Host) is the right choice for most callers. Choosing
/// [`Device`](Self::Device) avoids a host-device synchronisation barrier when
/// the scalars have already been produced on the GPU (e.g. inside a training
/// loop).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum PointerMode {
    /// Scalars are read from host memory (default).
    Host,
    /// Scalars are read from device memory.
    Device,
}
53
54// ---------------------------------------------------------------------------
55// Layout — memory ordering of a dense matrix
56// ---------------------------------------------------------------------------
57
/// Storage order of a dense matrix in memory.
///
/// In the formulas below `ld` is the leading dimension of the matrix.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Layout {
    /// Row-major (C-style) storage: `A[i][j] = ptr[i * ld + j]`.
    RowMajor,
    /// Column-major (Fortran-style) storage: `A[i][j] = ptr[j * ld + i]`.
    ColMajor,
}
66
67// ---------------------------------------------------------------------------
68// Transpose — matrix transposition mode
69// ---------------------------------------------------------------------------
70
/// How a matrix operand is to be interpreted by a routine.
///
/// Corresponds to the classic BLAS `TRANSA` / `TRANSB` parameter: the operand
/// is taken unchanged, transposed, or conjugate-transposed.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Transpose {
    /// No transposition: the matrix is used as stored.
    NoTrans,
    /// The transpose of the matrix (A^T) is used.
    Trans,
    /// The conjugate-transpose (A^H) is used. For real element types this is
    /// the same as [`Trans`](Self::Trans).
    ConjTrans,
}
85
86// ---------------------------------------------------------------------------
87// FillMode — upper / lower triangle selection
88// ---------------------------------------------------------------------------
89
/// Selects the stored part of a symmetric or triangular matrix.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum FillMode {
    /// Only the upper triangle is stored / referenced.
    Upper,
    /// Only the lower triangle is stored / referenced.
    Lower,
    /// The full matrix (both triangles) is stored / referenced.
    Full,
}
100
101// ---------------------------------------------------------------------------
102// Side — left/right operand position
103// ---------------------------------------------------------------------------
104
/// Position of the special (symmetric / triangular) matrix in a two-operand
/// BLAS-3 operation.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Side {
    /// The special matrix is the left operand: `op(A) * B`.
    Left,
    /// The special matrix is the right operand: `B * op(A)`.
    Right,
}
114
115// ---------------------------------------------------------------------------
116// DiagType — unit / non-unit diagonal
117// ---------------------------------------------------------------------------
118
/// Says whether the diagonal of a triangular matrix is implicit.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum DiagType {
    /// Diagonal entries are stored explicitly and used as they are.
    NonUnit,
    /// The diagonal is treated as all ones (unit diagonal) without being read.
    Unit,
}
127
128// ---------------------------------------------------------------------------
129// GpuFloat — trait for GPU-compatible floating-point types
130// ---------------------------------------------------------------------------
131
132/// Trait for floating-point types that can be used in GPU BLAS kernels.
133///
134/// Provides the mapping between Rust types and PTX register types, element
135/// sizes, and bit-level representation for passing scalars as kernel parameters.
136///
137/// The trait bound is deliberately minimal so that half-precision types
138/// (`half::f16`, `half::bf16`) and FP8 types can implement it.
139pub trait GpuFloat: Copy + Send + Sync + 'static + std::fmt::Debug + PartialOrd {
140    /// The PTX register type used for this precision (e.g. `PtxType::F32`).
141    const PTX_TYPE: PtxType;
142
143    /// Size of one element in bytes.
144    const SIZE: usize;
145
146    /// A short name used in generated kernel names (e.g. `"f32"`, `"f64"`).
147    const NAME: &'static str;
148
149    /// Whether this type is eligible for Tensor-Core acceleration.
150    const TENSOR_CORE_ELIGIBLE: bool;
151
152    /// The accumulator type used when this type feeds a Tensor-Core MMA.
153    ///
154    /// For f32/f64 this is `Self`; for f16/bf16/FP8 it is typically `f32`.
155    type Accumulator: GpuFloat;
156
157    /// Converts the scalar to its raw bit representation as a `u64`.
158    ///
159    /// For `f32`, the upper 32 bits are zero. For `f64`, all 64 bits are used.
160    /// This is how scalar constants are passed to PTX kernels.
161    fn to_bits_u64(self) -> u64;
162
163    /// Reconstructs a value from its raw bit representation stored in a `u64`.
164    fn from_bits_u64(bits: u64) -> Self;
165
166    /// The zero value for this type (additive identity).
167    fn gpu_zero() -> Self;
168
169    /// The one value for this type (multiplicative identity).
170    fn gpu_one() -> Self;
171
172    /// Size of one element in bytes, as `u32`.
173    ///
174    /// Convenience helper for PTX code-generation where `u32` strides are
175    /// expected (e.g. `byte_offset_addr`).
176    #[inline]
177    fn size_u32() -> u32 {
178        Self::SIZE as u32
179    }
180}
181
182// -- f32 impl -----------------------------------------------------------------
183
184impl GpuFloat for f32 {
185    const PTX_TYPE: PtxType = PtxType::F32;
186    const SIZE: usize = 4;
187    const NAME: &'static str = "f32";
188    const TENSOR_CORE_ELIGIBLE: bool = true;
189    type Accumulator = f32;
190
191    #[inline]
192    fn to_bits_u64(self) -> u64 {
193        u64::from(self.to_bits())
194    }
195
196    #[inline]
197    fn from_bits_u64(bits: u64) -> Self {
198        f32::from_bits(bits as u32)
199    }
200
201    #[inline]
202    fn gpu_zero() -> Self {
203        0.0
204    }
205
206    #[inline]
207    fn gpu_one() -> Self {
208        1.0
209    }
210}
211
212// -- f64 impl -----------------------------------------------------------------
213
214impl GpuFloat for f64 {
215    const PTX_TYPE: PtxType = PtxType::F64;
216    const SIZE: usize = 8;
217    const NAME: &'static str = "f64";
218    const TENSOR_CORE_ELIGIBLE: bool = true;
219    type Accumulator = f64;
220
221    #[inline]
222    fn to_bits_u64(self) -> u64 {
223        self.to_bits()
224    }
225
226    #[inline]
227    fn from_bits_u64(bits: u64) -> Self {
228        f64::from_bits(bits)
229    }
230
231    #[inline]
232    fn gpu_zero() -> Self {
233        0.0
234    }
235
236    #[inline]
237    fn gpu_one() -> Self {
238        1.0
239    }
240}
241
242// -- half::f16 impl (feature-gated) ------------------------------------------
243
/// Half-precision implementation: 2-byte elements that accumulate in `f32`.
/// Only compiled when the `f16` feature is enabled.
#[cfg(feature = "f16")]
impl GpuFloat for half::f16 {
    type Accumulator = f32;

    const PTX_TYPE: PtxType = PtxType::F16;
    const SIZE: usize = 2;
    const NAME: &'static str = "f16";
    const TENSOR_CORE_ELIGIBLE: bool = true;

    /// Zero-extends the 16-bit pattern into a `u64`.
    #[inline]
    fn to_bits_u64(self) -> u64 {
        let raw: u16 = self.to_bits();
        raw.into()
    }

    /// Interprets the low 16 bits of `bits`; higher bits are discarded.
    #[inline]
    fn from_bits_u64(bits: u64) -> Self {
        let low = bits as u16;
        Self::from_bits(low)
    }

    #[inline]
    fn gpu_zero() -> Self {
        Self::ZERO
    }

    #[inline]
    fn gpu_one() -> Self {
        Self::ONE
    }
}
272
273// -- half::bf16 impl (feature-gated) -----------------------------------------
274
/// bfloat16 implementation: 2-byte elements that accumulate in `f32`.
/// Only compiled when the `f16` feature is enabled.
#[cfg(feature = "f16")]
impl GpuFloat for half::bf16 {
    type Accumulator = f32;

    const PTX_TYPE: PtxType = PtxType::BF16;
    const SIZE: usize = 2;
    const NAME: &'static str = "bf16";
    const TENSOR_CORE_ELIGIBLE: bool = true;

    /// Zero-extends the 16-bit pattern into a `u64`.
    #[inline]
    fn to_bits_u64(self) -> u64 {
        let raw: u16 = self.to_bits();
        raw.into()
    }

    /// Interprets the low 16 bits of `bits`; higher bits are discarded.
    #[inline]
    fn from_bits_u64(bits: u64) -> Self {
        let low = bits as u16;
        Self::from_bits(low)
    }

    #[inline]
    fn gpu_zero() -> Self {
        Self::ZERO
    }

    #[inline]
    fn gpu_one() -> Self {
        Self::ONE
    }
}
303
304// ---------------------------------------------------------------------------
305// FP8 types — Hopper+ (SM90) reduced-precision formats
306// ---------------------------------------------------------------------------
307
/// FP8 E4M3 format (4-bit exponent, 3-bit mantissa).
///
/// Used primarily for inference on Hopper+ GPUs. The dynamic range is smaller
/// than E5M2 but the extra mantissa bit gives better precision for weights
/// and activations that stay within range.
///
/// The wrapped `u8` is the raw bit pattern (sign bit in the top bit).
///
/// # Comparison semantics
///
/// * `==` compares bit patterns, so `+0` and `-0` are distinct and a NaN
///   pattern equals itself.
/// * `<` / `>` follow a total order on the sign-magnitude encoding (the same
///   scheme as [`f32::total_cmp`]): negative values sort below positive
///   ones, larger magnitudes sort further from zero, `-0 < +0`, and NaN
///   patterns sort at the extremes of their sign.
#[derive(Debug, Clone, Copy, PartialEq)]
#[repr(transparent)]
pub struct E4M3(pub u8);

/// FP8 E5M2 format (5-bit exponent, 2-bit mantissa).
///
/// Used primarily for training gradients on Hopper+ GPUs. The wider exponent
/// range accommodates the larger dynamic range of gradient values.
///
/// The wrapped `u8` is the raw bit pattern (sign bit in the top bit).
/// Comparison semantics are the same as for [`E4M3`]: bitwise equality and a
/// sign-aware total order.
#[derive(Debug, Clone, Copy, PartialEq)]
#[repr(transparent)]
pub struct E5M2(pub u8);

// `E4M3` and `E5M2` wrap a plain `u8`, so the compiler already derives
// `Send` and `Sync` for them; no `unsafe impl` is required.

/// Maps a sign-magnitude FP8 bit pattern to a key whose unsigned integer
/// order matches the numeric order of the encoded values.
///
/// Negative patterns (sign bit set) have all bits flipped, which reverses
/// their order and places them in the lower half; non-negative patterns get
/// the top bit set, which places them in the upper half. The mapping is a
/// bijection, so equal keys imply equal bit patterns.
#[inline]
const fn fp8_order_key(bits: u8) -> u8 {
    if bits & 0x80 != 0 {
        !bits
    } else {
        bits | 0x80
    }
}

impl PartialOrd for E4M3 {
    /// Orders by encoded value rather than by raw byte, so that negative
    /// numbers compare less than positive ones. Always returns `Some`.
    #[inline]
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(fp8_order_key(self.0).cmp(&fp8_order_key(other.0)))
    }
}

impl PartialOrd for E5M2 {
    /// Orders by encoded value rather than by raw byte, so that negative
    /// numbers compare less than positive ones. Always returns `Some`.
    #[inline]
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(fp8_order_key(self.0).cmp(&fp8_order_key(other.0)))
    }
}
331
332impl GpuFloat for E4M3 {
333    const PTX_TYPE: PtxType = PtxType::E4M3;
334    const SIZE: usize = 1;
335    const NAME: &'static str = "e4m3";
336    const TENSOR_CORE_ELIGIBLE: bool = true;
337    type Accumulator = f32;
338
339    #[inline]
340    fn to_bits_u64(self) -> u64 {
341        u64::from(self.0)
342    }
343
344    #[inline]
345    fn from_bits_u64(bits: u64) -> Self {
346        Self(bits as u8)
347    }
348
349    #[inline]
350    fn gpu_zero() -> Self {
351        Self(0x00)
352    }
353
354    #[inline]
355    fn gpu_one() -> Self {
356        Self(0x38)
357    }
358}
359
360impl GpuFloat for E5M2 {
361    const PTX_TYPE: PtxType = PtxType::E5M2;
362    const SIZE: usize = 1;
363    const NAME: &'static str = "e5m2";
364    const TENSOR_CORE_ELIGIBLE: bool = true;
365    type Accumulator = f32;
366
367    #[inline]
368    fn to_bits_u64(self) -> u64 {
369        u64::from(self.0)
370    }
371
372    #[inline]
373    fn from_bits_u64(bits: u64) -> Self {
374        Self(bits as u8)
375    }
376
377    #[inline]
378    fn gpu_zero() -> Self {
379        Self(0x00)
380    }
381
382    #[inline]
383    fn gpu_one() -> Self {
384        Self(0x3C)
385    }
386}
387
388// ---------------------------------------------------------------------------
389// VectorDesc — describes a strided vector on the device
390// ---------------------------------------------------------------------------
391
/// Layout description of a strided vector in device memory.
///
/// BLAS Level 1 routines accept vectors whose consecutive logical elements
/// are separated by a stride (increment). The descriptor stores the logical
/// length and that stride; the buffer capacity they imply is available from
/// [`required_elements`](Self::required_elements).
#[derive(Clone, Copy, Debug)]
pub struct VectorDesc {
    /// Count of logical elements.
    pub n: u32,
    /// Distance, in elements, between consecutive logical elements. Must be
    /// positive.
    pub inc: u32,
}

impl VectorDesc {
    /// Builds a descriptor from a length and a stride.
    ///
    /// # Arguments
    ///
    /// * `n` — number of logical elements.
    /// * `inc` — stride between elements (the absolute value of the increment
    ///   supplied by the user). Must be at least 1; this is not checked here.
    #[must_use]
    pub fn new(n: u32, inc: u32) -> Self {
        VectorDesc { n, inc }
    }

    /// Smallest number of elements the backing buffer has to provide.
    ///
    /// The last of `n` elements with stride `inc` sits at index
    /// `(n - 1) * inc`, hence `(n - 1) * inc + 1` elements are needed. An
    /// empty vector needs none.
    #[must_use]
    pub fn required_elements(&self) -> usize {
        match self.n as usize {
            0 => 0,
            count => (count - 1) * self.inc as usize + 1,
        }
    }
}
431
432// ---------------------------------------------------------------------------
433// MatrixDesc — describes a dense matrix on the device (immutable view)
434// ---------------------------------------------------------------------------
435
436/// Describes a matrix stored in device memory.
437///
438/// This is an immutable (read-only) view. For an output matrix that will be
439/// written to, use [`MatrixDescMut`].
440///
441/// All fields are `Copy`-sized, so `MatrixDesc` itself is `Copy`.
442#[derive(Debug, Clone, Copy)]
443pub struct MatrixDesc<T: GpuFloat> {
444    /// Device pointer to the matrix data.
445    pub ptr: CUdeviceptr,
446    /// Number of rows.
447    pub rows: u32,
448    /// Number of columns.
449    pub cols: u32,
450    /// Leading dimension (stride between rows/columns depending on layout).
451    pub ld: u32,
452    /// Memory layout.
453    pub layout: Layout,
454    _phantom: PhantomData<T>,
455}
456
457impl<T: GpuFloat> MatrixDesc<T> {
458    /// Create a matrix descriptor from a [`DeviceBuffer`].
459    ///
460    /// Returns an error if the buffer is too small for the requested dimensions.
461    pub fn from_buffer(
462        buf: &DeviceBuffer<T>,
463        rows: u32,
464        cols: u32,
465        layout: Layout,
466    ) -> BlasResult<Self> {
467        let required = rows as usize * cols as usize;
468        if buf.len() < required {
469            return Err(BlasError::BufferTooSmall {
470                expected: required,
471                actual: buf.len(),
472            });
473        }
474        let ld = match layout {
475            Layout::RowMajor => cols,
476            Layout::ColMajor => rows,
477        };
478        Ok(Self {
479            ptr: buf.as_device_ptr(),
480            rows,
481            cols,
482            ld,
483            layout,
484            _phantom: PhantomData,
485        })
486    }
487
488    /// Create with a raw device pointer (no size validation).
489    pub fn from_raw(ptr: CUdeviceptr, rows: u32, cols: u32, ld: u32, layout: Layout) -> Self {
490        Self {
491            ptr,
492            rows,
493            cols,
494            ld,
495            layout,
496            _phantom: PhantomData,
497        }
498    }
499
500    /// Override the leading dimension.
501    #[must_use]
502    pub fn with_ld(mut self, ld: u32) -> Self {
503        self.ld = ld;
504        self
505    }
506
507    /// Total number of elements.
508    #[must_use]
509    pub fn numel(&self) -> usize {
510        self.rows as usize * self.cols as usize
511    }
512
513    /// Storage bytes (full stride, including padding from leading dimension).
514    #[must_use]
515    pub fn storage_bytes(&self) -> usize {
516        let major = match self.layout {
517            Layout::RowMajor => self.rows,
518            Layout::ColMajor => self.cols,
519        };
520        major as usize * self.ld as usize * T::SIZE
521    }
522
523    /// Effective dimensions after transpose.
524    #[must_use]
525    pub fn effective_dims(&self, trans: Transpose) -> (u32, u32) {
526        match trans {
527            Transpose::NoTrans => (self.rows, self.cols),
528            Transpose::Trans | Transpose::ConjTrans => (self.cols, self.rows),
529        }
530    }
531}
532
533// ---------------------------------------------------------------------------
534// MatrixDescMut — mutable matrix descriptor
535// ---------------------------------------------------------------------------
536
537/// Describes a mutable (output) matrix stored in device memory.
538///
539/// Identical to [`MatrixDesc`] but signals intent to write. This distinction
540/// prevents accidentally passing an input buffer where an output is expected.
541#[derive(Debug, Clone, Copy)]
542pub struct MatrixDescMut<T: GpuFloat> {
543    /// Device pointer to the matrix data.
544    pub ptr: CUdeviceptr,
545    /// Number of rows.
546    pub rows: u32,
547    /// Number of columns.
548    pub cols: u32,
549    /// Leading dimension (stride between rows/columns depending on layout).
550    pub ld: u32,
551    /// Memory layout.
552    pub layout: Layout,
553    _phantom: PhantomData<T>,
554}
555
556impl<T: GpuFloat> MatrixDescMut<T> {
557    /// Create a mutable matrix descriptor from a [`DeviceBuffer`].
558    ///
559    /// Returns an error if the buffer is too small for the requested dimensions.
560    pub fn from_buffer(
561        buf: &mut DeviceBuffer<T>,
562        rows: u32,
563        cols: u32,
564        layout: Layout,
565    ) -> BlasResult<Self> {
566        let required = rows as usize * cols as usize;
567        if buf.len() < required {
568            return Err(BlasError::BufferTooSmall {
569                expected: required,
570                actual: buf.len(),
571            });
572        }
573        let ld = match layout {
574            Layout::RowMajor => cols,
575            Layout::ColMajor => rows,
576        };
577        Ok(Self {
578            ptr: buf.as_device_ptr(),
579            rows,
580            cols,
581            ld,
582            layout,
583            _phantom: PhantomData,
584        })
585    }
586
587    /// Create with a raw device pointer (no size validation).
588    pub fn from_raw(ptr: CUdeviceptr, rows: u32, cols: u32, ld: u32, layout: Layout) -> Self {
589        Self {
590            ptr,
591            rows,
592            cols,
593            ld,
594            layout,
595            _phantom: PhantomData,
596        }
597    }
598
599    /// Override the leading dimension.
600    #[must_use]
601    pub fn with_ld(mut self, ld: u32) -> Self {
602        self.ld = ld;
603        self
604    }
605
606    /// Total number of elements.
607    #[must_use]
608    pub fn numel(&self) -> usize {
609        self.rows as usize * self.cols as usize
610    }
611
612    /// Storage bytes (full stride, including padding from leading dimension).
613    #[must_use]
614    pub fn storage_bytes(&self) -> usize {
615        let major = match self.layout {
616            Layout::RowMajor => self.rows,
617            Layout::ColMajor => self.cols,
618        };
619        major as usize * self.ld as usize * T::SIZE
620    }
621
622    /// Effective dimensions after transpose.
623    #[must_use]
624    pub fn effective_dims(&self, trans: Transpose) -> (u32, u32) {
625        match trans {
626            Transpose::NoTrans => (self.rows, self.cols),
627            Transpose::Trans | Transpose::ConjTrans => (self.cols, self.rows),
628        }
629    }
630
631    /// Borrow as an immutable [`MatrixDesc`].
632    #[must_use]
633    pub fn as_immutable(&self) -> MatrixDesc<T> {
634        MatrixDesc {
635            ptr: self.ptr,
636            rows: self.rows,
637            cols: self.cols,
638            ld: self.ld,
639            layout: self.layout,
640            _phantom: PhantomData,
641        }
642    }
643}