Skip to main content

gam_gpu/
memory.rs

1use ndarray::{Array1, Array2};
2
3use super::cpu_traits::MatrixLocation;
4
5#[derive(Clone, Debug)]
6pub struct DeviceBuffer<T> {
7    host_shadow: Vec<T>,
8    location: MatrixLocation,
9}
10
11impl<T> DeviceBuffer<T> {
12    pub const fn from_host_shadow(host_shadow: Vec<T>) -> Self {
13        Self {
14            host_shadow,
15            location: MatrixLocation::Host,
16        }
17    }
18
19    pub const fn len(&self) -> usize {
20        self.host_shadow.len()
21    }
22
23    pub const fn is_empty(&self) -> bool {
24        self.host_shadow.len() == 0
25    }
26
27    pub const fn location(&self) -> MatrixLocation {
28        self.location
29    }
30
31    pub fn host_shadow(&self) -> &[T] {
32        &self.host_shadow
33    }
34}
35
36#[derive(Clone, Debug)]
37pub struct DeviceVector {
38    pub len: usize,
39    pub data: DeviceBuffer<f64>,
40}
41
42impl DeviceVector {
43    pub fn from_array(array: &Array1<f64>) -> Self {
44        Self {
45            len: array.len(),
46            data: DeviceBuffer::from_host_shadow(array.to_vec()),
47        }
48    }
49}
50
51#[derive(Clone, Debug)]
52pub struct DeviceMatrix {
53    pub rows: usize,
54    pub cols: usize,
55    pub data: DeviceBuffer<f64>,
56    pub column_major: bool,
57}
58
59impl DeviceMatrix {
60    pub fn from_array(array: &Array2<f64>) -> Self {
61        Self {
62            rows: array.nrows(),
63            cols: array.ncols(),
64            data: DeviceBuffer::from_host_shadow(array.iter().copied().collect()),
65            column_major: false,
66        }
67    }
68
69    pub const fn bytes(&self) -> usize {
70        self.rows
71            .saturating_mul(self.cols)
72            .saturating_mul(std::mem::size_of::<f64>())
73    }
74}
75
76#[derive(Clone, Debug)]
77pub struct DeviceCsrMatrix {
78    pub rows: usize,
79    pub cols: usize,
80    pub rowptr: DeviceBuffer<i32>,
81    pub colidx: DeviceBuffer<i32>,
82    pub values: DeviceBuffer<f64>,
83}
84
85impl DeviceCsrMatrix {
86    pub const fn nnz(&self) -> usize {
87        self.values.len()
88    }
89}