1use std::marker::PhantomData;
2
3use singe_cuda::{memory::DeviceMemory, types::DevicePtr};
4
5#[derive(Debug, Clone, Copy)]
7pub struct MatrixRef<'a, T> {
8 pub data: &'a DeviceMemory<T>,
9 pub leading_dimension: usize,
10}
11
12#[derive(Debug)]
14pub struct MatrixMut<'a, T> {
15 pub data: &'a mut DeviceMemory<T>,
16 pub leading_dimension: usize,
17}
18
19#[derive(Debug, Clone, Copy)]
21pub struct StridedBatchedMatrixRef<'a, T> {
22 pub data: &'a DeviceMemory<T>,
23 pub leading_dimension: usize,
24 pub stride: usize,
25}
26
27#[derive(Debug)]
29pub struct StridedBatchedMatrixMut<'a, T> {
30 pub data: &'a mut DeviceMemory<T>,
31 pub leading_dimension: usize,
32 pub stride: usize,
33}
34
35#[derive(Debug, Clone, Copy)]
37pub struct VectorRef<'a, T> {
38 pub data: &'a DeviceMemory<T>,
39}
40
41#[derive(Debug)]
43pub struct VectorMut<'a, T> {
44 pub data: &'a mut DeviceMemory<T>,
45}
46
47#[derive(Debug, Clone, Copy)]
49pub struct StridedBatchedVectorRef<'a, T> {
50 pub data: &'a DeviceMemory<T>,
51 pub stride: usize,
52}
53
54#[derive(Debug)]
56pub struct StridedBatchedVectorMut<'a, T> {
57 pub data: &'a mut DeviceMemory<T>,
58 pub stride: usize,
59}
60
61#[derive(Debug, Clone, Copy)]
63pub struct BatchedMatrixRef<'a, T> {
64 pub pointers: &'a DeviceMemory<DevicePtr>,
65 pub leading_dimension: usize,
66 _phantom: PhantomData<T>,
67}
68
69#[derive(Debug, Clone, Copy)]
71pub struct BatchedVectorRef<'a, T> {
72 pub pointers: &'a DeviceMemory<DevicePtr>,
73 pub leading_dimension: usize,
74 _phantom: PhantomData<T>,
75}
76
77#[derive(Debug)]
79pub struct ByteWorkspaceMut<'a> {
80 pub device: &'a mut DeviceMemory<u8>,
81 pub host: &'a mut [u8],
82}
83
/// Workspace requirement expressed as separate device and host byte counts.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct WorkspaceSizes {
    /// Number of device-side workspace bytes required.
    pub device_bytes: usize,
    /// Number of host-side workspace bytes required.
    pub host_bytes: usize,
}
89
90#[derive(Debug, Clone, Copy, PartialEq, Eq)]
91pub struct SelectionWorkspaceSizes {
92 pub selection_size: usize,
93 pub workspace: WorkspaceSizes,
94}
95
96impl<'a, T> MatrixRef<'a, T> {
97 pub const fn new(data: &'a DeviceMemory<T>, leading_dimension: usize) -> Self {
98 Self {
99 data,
100 leading_dimension,
101 }
102 }
103}
104
105impl<'a, T> MatrixMut<'a, T> {
106 pub const fn new(data: &'a mut DeviceMemory<T>, leading_dimension: usize) -> Self {
107 Self {
108 data,
109 leading_dimension,
110 }
111 }
112
113 pub fn as_ref(&self) -> MatrixRef<'_, T> {
114 MatrixRef::new(self.data, self.leading_dimension)
115 }
116}
117
118impl<'a, T> StridedBatchedMatrixRef<'a, T> {
119 pub const fn new(data: &'a DeviceMemory<T>, leading_dimension: usize, stride: usize) -> Self {
120 Self {
121 data,
122 leading_dimension,
123 stride,
124 }
125 }
126}
127
128impl<'a, T> StridedBatchedMatrixMut<'a, T> {
129 pub const fn new(
130 data: &'a mut DeviceMemory<T>,
131 leading_dimension: usize,
132 stride: usize,
133 ) -> Self {
134 Self {
135 data,
136 leading_dimension,
137 stride,
138 }
139 }
140
141 pub fn as_ref(&self) -> StridedBatchedMatrixRef<'_, T> {
142 StridedBatchedMatrixRef::new(self.data, self.leading_dimension, self.stride)
143 }
144}
145
146impl<'a, T> VectorRef<'a, T> {
147 pub const fn new(data: &'a DeviceMemory<T>) -> Self {
148 Self { data }
149 }
150}
151
152impl<'a, T> VectorMut<'a, T> {
153 pub const fn new(data: &'a mut DeviceMemory<T>) -> Self {
154 Self { data }
155 }
156
157 pub fn as_ref(&self) -> VectorRef<'_, T> {
158 VectorRef::new(self.data)
159 }
160}
161
162impl<'a, T> StridedBatchedVectorRef<'a, T> {
163 pub const fn new(data: &'a DeviceMemory<T>, stride: usize) -> Self {
164 Self { data, stride }
165 }
166}
167
168impl<'a, T> StridedBatchedVectorMut<'a, T> {
169 pub const fn new(data: &'a mut DeviceMemory<T>, stride: usize) -> Self {
170 Self { data, stride }
171 }
172
173 pub fn as_ref(&self) -> StridedBatchedVectorRef<'_, T> {
174 StridedBatchedVectorRef::new(self.data, self.stride)
175 }
176}
177
178impl<'a, T> BatchedMatrixRef<'a, T> {
179 pub const fn new(pointers: &'a DeviceMemory<DevicePtr>, leading_dimension: usize) -> Self {
180 Self {
181 pointers,
182 leading_dimension,
183 _phantom: PhantomData,
184 }
185 }
186
187 pub const fn len(&self) -> usize {
188 self.pointers.len()
189 }
190
191 pub const fn is_empty(&self) -> bool {
192 self.pointers.is_empty()
193 }
194
195 pub const fn as_mut_ptr(&self) -> *mut *mut T {
196 self.pointers.as_ptr().cast::<*mut T>().cast_mut()
197 }
198}
199
200impl<'a, T> BatchedVectorRef<'a, T> {
201 pub const fn new(pointers: &'a DeviceMemory<DevicePtr>, leading_dimension: usize) -> Self {
202 Self {
203 pointers,
204 leading_dimension,
205 _phantom: PhantomData,
206 }
207 }
208
209 pub const fn len(&self) -> usize {
210 self.pointers.len()
211 }
212
213 pub const fn is_empty(&self) -> bool {
214 self.pointers.is_empty()
215 }
216
217 pub const fn as_mut_ptr(&self) -> *mut *mut T {
218 self.pointers.as_ptr().cast::<*mut T>().cast_mut()
219 }
220}
221
222impl<'a> ByteWorkspaceMut<'a> {
223 pub const fn new(device: &'a mut DeviceMemory<u8>, host: &'a mut [u8]) -> Self {
224 Self { device, host }
225 }
226}
227
228impl WorkspaceSizes {
229 pub const fn new(device_bytes: usize, host_bytes: usize) -> Self {
230 Self {
231 device_bytes,
232 host_bytes,
233 }
234 }
235}
236
237impl SelectionWorkspaceSizes {
238 pub const fn new(selection_size: usize, device_bytes: usize, host_bytes: usize) -> Self {
239 Self {
240 selection_size,
241 workspace: WorkspaceSizes::new(device_bytes, host_bytes),
242 }
243 }
244}