baracuda_kernels_types/matrix.rs
1//! Borrowed views of device-resident matrices and vectors.
2//!
3//! Pure data — no hidden device allocations, no driver state. Plans
4//! cache *selection metadata* on top of these descriptors but never own
5//! device memory.
6//!
7//! The type parameter `T` is bounded by [`DeviceRepr`] (not by the
8//! [`crate::Element`] / [`crate::IntElement`] traits) so the same view
9//! structs can be re-used across float kernels, integer kernels, and
10//! arbitrary scalar-typed bias / aux buffers. Semantic enforcement
11//! (which element types a given plan accepts) happens at the plan layer
12//! via the appropriate trait bound.
13
14use baracuda_driver::{DeviceSlice, DeviceSliceMut};
15use baracuda_types::DeviceRepr;
16
17/// Read-only view of a device-resident matrix.
18///
19/// `ld` is the leading dimension in **elements** (not bytes), measured
20/// along the major axis dictated by the layout: row-stride for row-major
21/// matrices, column-stride for column-major matrices.
22#[derive(Debug)]
23pub struct MatrixRef<'a, T: DeviceRepr + Copy + 'static> {
24 /// Device-resident element storage.
25 pub data: DeviceSlice<'a, T>,
26 /// Number of rows.
27 pub rows: i32,
28 /// Number of columns.
29 pub cols: i32,
30 /// Leading dimension in elements.
31 pub ld: i64,
32}
33
34/// Mutable view of a device-resident matrix (used for the output `D`).
35///
36/// See [`MatrixRef`] for the rationale behind the relaxed `T` bound.
37#[derive(Debug)]
38pub struct MatrixMut<'a, T: DeviceRepr + Copy + 'static> {
39 /// Device-resident element storage.
40 pub data: DeviceSliceMut<'a, T>,
41 /// Number of rows.
42 pub rows: i32,
43 /// Number of columns.
44 pub cols: i32,
45 /// Leading dimension in elements.
46 pub ld: i64,
47}
48
49/// Read-only view of a device-resident vector.
50///
51/// See [`MatrixRef`] for the rationale behind the relaxed `T` bound.
52#[derive(Debug)]
53pub struct VectorRef<'a, T: DeviceRepr + Copy + 'static> {
54 /// Device-resident element storage.
55 pub data: DeviceSlice<'a, T>,
56 /// Number of elements.
57 pub len: i32,
58 /// Stride in elements.
59 pub stride: i64,
60}