1use ndarray::{Array1, Array2};
2
3use super::cpu_traits::MatrixLocation;
4
5#[derive(Clone, Debug)]
6pub struct DeviceBuffer<T> {
7 host_shadow: Vec<T>,
8 location: MatrixLocation,
9}
10
11impl<T> DeviceBuffer<T> {
12 pub const fn from_host_shadow(host_shadow: Vec<T>) -> Self {
13 Self {
14 host_shadow,
15 location: MatrixLocation::Host,
16 }
17 }
18
19 pub const fn len(&self) -> usize {
20 self.host_shadow.len()
21 }
22
23 pub const fn is_empty(&self) -> bool {
24 self.host_shadow.len() == 0
25 }
26
27 pub const fn location(&self) -> MatrixLocation {
28 self.location
29 }
30
31 pub fn host_shadow(&self) -> &[T] {
32 &self.host_shadow
33 }
34}
35
36#[derive(Clone, Debug)]
37pub struct DeviceVector {
38 pub len: usize,
39 pub data: DeviceBuffer<f64>,
40}
41
42impl DeviceVector {
43 pub fn from_array(array: &Array1<f64>) -> Self {
44 Self {
45 len: array.len(),
46 data: DeviceBuffer::from_host_shadow(array.to_vec()),
47 }
48 }
49}
50
51#[derive(Clone, Debug)]
52pub struct DeviceMatrix {
53 pub rows: usize,
54 pub cols: usize,
55 pub data: DeviceBuffer<f64>,
56 pub column_major: bool,
57}
58
59impl DeviceMatrix {
60 pub fn from_array(array: &Array2<f64>) -> Self {
61 Self {
62 rows: array.nrows(),
63 cols: array.ncols(),
64 data: DeviceBuffer::from_host_shadow(array.iter().copied().collect()),
65 column_major: false,
66 }
67 }
68
69 pub const fn bytes(&self) -> usize {
70 self.rows
71 .saturating_mul(self.cols)
72 .saturating_mul(std::mem::size_of::<f64>())
73 }
74}
75
76#[derive(Clone, Debug)]
77pub struct DeviceCsrMatrix {
78 pub rows: usize,
79 pub cols: usize,
80 pub rowptr: DeviceBuffer<i32>,
81 pub colidx: DeviceBuffer<i32>,
82 pub values: DeviceBuffer<f64>,
83}
84
85impl DeviceCsrMatrix {
86 pub const fn nnz(&self) -> usize {
87 self.values.len()
88 }
89}