Skip to main content

singe_cusolver/
layout.rs

1use std::marker::PhantomData;
2
3use singe_cuda::{memory::DeviceMemory, types::DevicePtr};
4
5/// A reference to a matrix stored on the device.
6#[derive(Debug, Clone, Copy)]
7pub struct MatrixRef<'a, T> {
8    pub data: &'a DeviceMemory<T>,
9    pub leading_dimension: usize,
10}
11
12/// A mutable reference to a matrix stored on the device.
13#[derive(Debug)]
14pub struct MatrixMut<'a, T> {
15    pub data: &'a mut DeviceMemory<T>,
16    pub leading_dimension: usize,
17}
18
19/// A reference to a strided batched matrix stored on the device.
20#[derive(Debug, Clone, Copy)]
21pub struct StridedBatchedMatrixRef<'a, T> {
22    pub data: &'a DeviceMemory<T>,
23    pub leading_dimension: usize,
24    pub stride: usize,
25}
26
27/// A mutable reference to a strided batched matrix stored on the device.
28#[derive(Debug)]
29pub struct StridedBatchedMatrixMut<'a, T> {
30    pub data: &'a mut DeviceMemory<T>,
31    pub leading_dimension: usize,
32    pub stride: usize,
33}
34
35/// A reference to a vector stored on the device.
36#[derive(Debug, Clone, Copy)]
37pub struct VectorRef<'a, T> {
38    pub data: &'a DeviceMemory<T>,
39}
40
41/// A mutable reference to a vector stored on the device.
42#[derive(Debug)]
43pub struct VectorMut<'a, T> {
44    pub data: &'a mut DeviceMemory<T>,
45}
46
47/// A reference to a strided batched vector stored on the device.
48#[derive(Debug, Clone, Copy)]
49pub struct StridedBatchedVectorRef<'a, T> {
50    pub data: &'a DeviceMemory<T>,
51    pub stride: usize,
52}
53
54/// A mutable reference to a strided batched vector stored on the device.
55#[derive(Debug)]
56pub struct StridedBatchedVectorMut<'a, T> {
57    pub data: &'a mut DeviceMemory<T>,
58    pub stride: usize,
59}
60
61/// A reference to a batched matrix (pointer array) stored on the device.
62#[derive(Debug, Clone, Copy)]
63pub struct BatchedMatrixRef<'a, T> {
64    pub pointers: &'a DeviceMemory<DevicePtr>,
65    pub leading_dimension: usize,
66    _phantom: PhantomData<T>,
67}
68
69/// A reference to a batched vector (pointer array) stored on the device.
70#[derive(Debug, Clone, Copy)]
71pub struct BatchedVectorRef<'a, T> {
72    pub pointers: &'a DeviceMemory<DevicePtr>,
73    pub leading_dimension: usize,
74    _phantom: PhantomData<T>,
75}
76
77/// A mutable reference to a byte workspace buffer on the device.
78#[derive(Debug)]
79pub struct ByteWorkspaceMut<'a> {
80    pub device: &'a mut DeviceMemory<u8>,
81    pub host: &'a mut [u8],
82}
83
/// Workspace requirements of a routine, split into device and host byte counts.
///
/// `Default` yields zero for both sizes, i.e. no workspace needed.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct WorkspaceSizes {
    /// Number of bytes of device scratch memory required.
    pub device_bytes: usize,
    /// Number of bytes of host scratch memory required.
    pub host_bytes: usize,
}
89
90#[derive(Debug, Clone, Copy, PartialEq, Eq)]
91pub struct SelectionWorkspaceSizes {
92    pub selection_size: usize,
93    pub workspace: WorkspaceSizes,
94}
95
96impl<'a, T> MatrixRef<'a, T> {
97    pub const fn new(data: &'a DeviceMemory<T>, leading_dimension: usize) -> Self {
98        Self {
99            data,
100            leading_dimension,
101        }
102    }
103}
104
105impl<'a, T> MatrixMut<'a, T> {
106    pub const fn new(data: &'a mut DeviceMemory<T>, leading_dimension: usize) -> Self {
107        Self {
108            data,
109            leading_dimension,
110        }
111    }
112
113    pub fn as_ref(&self) -> MatrixRef<'_, T> {
114        MatrixRef::new(self.data, self.leading_dimension)
115    }
116}
117
118impl<'a, T> StridedBatchedMatrixRef<'a, T> {
119    pub const fn new(data: &'a DeviceMemory<T>, leading_dimension: usize, stride: usize) -> Self {
120        Self {
121            data,
122            leading_dimension,
123            stride,
124        }
125    }
126}
127
128impl<'a, T> StridedBatchedMatrixMut<'a, T> {
129    pub const fn new(
130        data: &'a mut DeviceMemory<T>,
131        leading_dimension: usize,
132        stride: usize,
133    ) -> Self {
134        Self {
135            data,
136            leading_dimension,
137            stride,
138        }
139    }
140
141    pub fn as_ref(&self) -> StridedBatchedMatrixRef<'_, T> {
142        StridedBatchedMatrixRef::new(self.data, self.leading_dimension, self.stride)
143    }
144}
145
146impl<'a, T> VectorRef<'a, T> {
147    pub const fn new(data: &'a DeviceMemory<T>) -> Self {
148        Self { data }
149    }
150}
151
152impl<'a, T> VectorMut<'a, T> {
153    pub const fn new(data: &'a mut DeviceMemory<T>) -> Self {
154        Self { data }
155    }
156
157    pub fn as_ref(&self) -> VectorRef<'_, T> {
158        VectorRef::new(self.data)
159    }
160}
161
162impl<'a, T> StridedBatchedVectorRef<'a, T> {
163    pub const fn new(data: &'a DeviceMemory<T>, stride: usize) -> Self {
164        Self { data, stride }
165    }
166}
167
168impl<'a, T> StridedBatchedVectorMut<'a, T> {
169    pub const fn new(data: &'a mut DeviceMemory<T>, stride: usize) -> Self {
170        Self { data, stride }
171    }
172
173    pub fn as_ref(&self) -> StridedBatchedVectorRef<'_, T> {
174        StridedBatchedVectorRef::new(self.data, self.stride)
175    }
176}
177
178impl<'a, T> BatchedMatrixRef<'a, T> {
179    pub const fn new(pointers: &'a DeviceMemory<DevicePtr>, leading_dimension: usize) -> Self {
180        Self {
181            pointers,
182            leading_dimension,
183            _phantom: PhantomData,
184        }
185    }
186
187    pub const fn len(&self) -> usize {
188        self.pointers.len()
189    }
190
191    pub const fn is_empty(&self) -> bool {
192        self.pointers.is_empty()
193    }
194
195    pub const fn as_mut_ptr(&self) -> *mut *mut T {
196        self.pointers.as_ptr().cast::<*mut T>().cast_mut()
197    }
198}
199
200impl<'a, T> BatchedVectorRef<'a, T> {
201    pub const fn new(pointers: &'a DeviceMemory<DevicePtr>, leading_dimension: usize) -> Self {
202        Self {
203            pointers,
204            leading_dimension,
205            _phantom: PhantomData,
206        }
207    }
208
209    pub const fn len(&self) -> usize {
210        self.pointers.len()
211    }
212
213    pub const fn is_empty(&self) -> bool {
214        self.pointers.is_empty()
215    }
216
217    pub const fn as_mut_ptr(&self) -> *mut *mut T {
218        self.pointers.as_ptr().cast::<*mut T>().cast_mut()
219    }
220}
221
222impl<'a> ByteWorkspaceMut<'a> {
223    pub const fn new(device: &'a mut DeviceMemory<u8>, host: &'a mut [u8]) -> Self {
224        Self { device, host }
225    }
226}
227
228impl WorkspaceSizes {
229    pub const fn new(device_bytes: usize, host_bytes: usize) -> Self {
230        Self {
231            device_bytes,
232            host_bytes,
233        }
234    }
235}
236
237impl SelectionWorkspaceSizes {
238    pub const fn new(selection_size: usize, device_bytes: usize, host_bytes: usize) -> Self {
239        Self {
240            selection_size,
241            workspace: WorkspaceSizes::new(device_bytes, host_bytes),
242        }
243    }
244}