1use std::ptr;
2
3use singe_cuda_sys::{driver, runtime};
4
5use crate::{
6 graph::{Extent, Position},
7 memory::{ArrayHandle, MemoryCopyKind},
8 types::HostFunction,
9};
10
/// Raw pointer plus pitch and x/y sizes, mirroring `runtime::cudaPitchedPtr`.
///
/// This is a plain `Copy` bundle: it neither owns nor validates the memory
/// behind the pointer. Equality and hashing compare the pointer address and
/// the three size fields.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PitchedPtr {
    /// Base address. Private so it can only be set through the `unsafe`
    /// constructor; read it back with [`PitchedPtr::ptr`].
    ptr: *mut (),
    /// Forwarded to the `pitch` field of `cudaPitchedPtr`.
    pub pitch: usize,
    /// Forwarded to the `xsize` field of `cudaPitchedPtr`.
    pub x_size: usize,
    /// Forwarded to the `ysize` field of `cudaPitchedPtr`.
    pub y_size: usize,
}

impl PitchedPtr {
    /// Bundles `ptr` with its pitch and sizes without validating any of them.
    ///
    /// # Safety
    ///
    /// Nothing is checked here: the arguments are stored verbatim and handed
    /// to CUDA unchanged once the value is converted into
    /// `runtime::cudaPitchedPtr`.
    ///
    /// NOTE(review): the exact contract is not visible in this file.
    /// Presumably the caller must guarantee that `ptr`, `pitch`, `x_size` and
    /// `y_size` describe memory that stays valid for every CUDA operation the
    /// value is used in — confirm against the call sites.
    #[must_use]
    pub const unsafe fn new(ptr: *mut (), pitch: usize, x_size: usize, y_size: usize) -> Self {
        Self {
            ptr,
            pitch,
            x_size,
            y_size,
        }
    }

    /// Returns the raw base pointer exactly as it was passed to [`PitchedPtr::new`].
    #[must_use]
    pub const fn ptr(self) -> *mut () {
        self.ptr
    }
}
38
39#[derive(Debug, Clone, Copy)]
40pub struct MemoryCopy3DNodeParams {
41 src_array: Option<ArrayHandle>,
42 src_pos: Position,
43 src_ptr: PitchedPtr,
44 dst_array: Option<ArrayHandle>,
45 dst_pos: Position,
46 dst_ptr: PitchedPtr,
47 extent: Extent,
48 kind: MemoryCopyKind,
49}
50
51#[derive(Debug, Clone, Copy)]
52pub struct MemoryCopyToSymbolNodeParams {
53 symbol: *const (),
54 src: *const (),
55 count: usize,
56 offset: usize,
57 kind: MemoryCopyKind,
58}
59
60#[derive(Debug, Clone, Copy)]
61pub struct MemoryCopyFromSymbolNodeParams {
62 dst: *mut (),
63 symbol: *const (),
64 count: usize,
65 offset: usize,
66 kind: MemoryCopyKind,
67}
68
69#[derive(Debug, Clone, Copy)]
70pub struct HostNodeParams {
71 func: HostFunction,
72 user_data: *mut (),
73}
74
75impl HostNodeParams {
76 pub const unsafe fn new(func: HostFunction, user_data: *mut ()) -> Self {
82 Self { func, user_data }
83 }
84
85 pub const fn function(self) -> HostFunction {
86 self.func
87 }
88
89 pub const fn user_data(self) -> *mut () {
90 self.user_data
91 }
92}
93
94#[derive(Debug, Clone, Copy)]
95pub struct MemoryCopy1DNodeParams {
96 dst: *mut (),
97 src: *const (),
98 count: usize,
99 kind: MemoryCopyKind,
100}
101
102impl MemoryCopyToSymbolNodeParams {
103 pub const unsafe fn new(
110 symbol: *const (),
111 src: *const (),
112 count: usize,
113 offset: usize,
114 kind: MemoryCopyKind,
115 ) -> Self {
116 Self {
117 symbol,
118 src,
119 count,
120 offset,
121 kind,
122 }
123 }
124
125 pub const fn symbol(self) -> *const () {
126 self.symbol
127 }
128
129 pub const fn src(self) -> *const () {
130 self.src
131 }
132
133 pub const fn count(self) -> usize {
134 self.count
135 }
136
137 pub const fn offset(self) -> usize {
138 self.offset
139 }
140
141 pub const fn kind(self) -> MemoryCopyKind {
142 self.kind
143 }
144}
145
146impl MemoryCopyFromSymbolNodeParams {
147 pub const unsafe fn new(
155 dst: *mut (),
156 symbol: *const (),
157 count: usize,
158 offset: usize,
159 kind: MemoryCopyKind,
160 ) -> Self {
161 Self {
162 dst,
163 symbol,
164 count,
165 offset,
166 kind,
167 }
168 }
169
170 pub const fn dst(self) -> *mut () {
171 self.dst
172 }
173
174 pub const fn symbol(self) -> *const () {
175 self.symbol
176 }
177
178 pub const fn count(self) -> usize {
179 self.count
180 }
181
182 pub const fn offset(self) -> usize {
183 self.offset
184 }
185
186 pub const fn kind(self) -> MemoryCopyKind {
187 self.kind
188 }
189}
190
191impl MemoryCopy1DNodeParams {
192 pub const unsafe fn new(
198 dst: *mut (),
199 src: *const (),
200 count: usize,
201 kind: MemoryCopyKind,
202 ) -> Self {
203 Self {
204 dst,
205 src,
206 count,
207 kind,
208 }
209 }
210
211 pub const fn dst(self) -> *mut () {
212 self.dst
213 }
214
215 pub const fn src(self) -> *const () {
216 self.src
217 }
218
219 pub const fn count(self) -> usize {
220 self.count
221 }
222
223 pub const fn kind(self) -> MemoryCopyKind {
224 self.kind
225 }
226}
227
228impl MemoryCopy3DNodeParams {
229 pub const unsafe fn new(
238 src_array: Option<ArrayHandle>,
239 src_pos: Position,
240 src_ptr: PitchedPtr,
241 dst_array: Option<ArrayHandle>,
242 dst_pos: Position,
243 dst_ptr: PitchedPtr,
244 extent: Extent,
245 kind: MemoryCopyKind,
246 ) -> Self {
247 Self {
248 src_array,
249 src_pos,
250 src_ptr,
251 dst_array,
252 dst_pos,
253 dst_ptr,
254 extent,
255 kind,
256 }
257 }
258
259 pub const fn src_array(self) -> Option<ArrayHandle> {
260 self.src_array
261 }
262
263 pub const fn src_pos(self) -> Position {
264 self.src_pos
265 }
266
267 pub const fn src_ptr(self) -> PitchedPtr {
268 self.src_ptr
269 }
270
271 pub const fn dst_array(self) -> Option<ArrayHandle> {
272 self.dst_array
273 }
274
275 pub const fn dst_pos(self) -> Position {
276 self.dst_pos
277 }
278
279 pub const fn dst_ptr(self) -> PitchedPtr {
280 self.dst_ptr
281 }
282
283 pub const fn extent(self) -> Extent {
284 self.extent
285 }
286
287 pub const fn kind(self) -> MemoryCopyKind {
288 self.kind
289 }
290}
291
292impl From<PitchedPtr> for runtime::cudaPitchedPtr {
293 fn from(value: PitchedPtr) -> Self {
294 Self {
295 ptr: value.ptr().cast(),
296 pitch: value.pitch as _,
297 xsize: value.x_size as _,
298 ysize: value.y_size as _,
299 }
300 }
301}
302
303impl From<&MemoryCopy3DNodeParams> for runtime::cudaMemcpy3DParms {
304 fn from(value: &MemoryCopy3DNodeParams) -> Self {
305 Self {
306 srcArray: value
307 .src_array()
308 .map_or(ptr::null_mut(), ArrayHandle::as_raw),
309 srcPos: value.src_pos().into(),
310 srcPtr: value.src_ptr().into(),
311 dstArray: value
312 .dst_array()
313 .map_or(ptr::null_mut(), ArrayHandle::as_raw),
314 dstPos: value.dst_pos().into(),
315 dstPtr: value.dst_ptr().into(),
316 extent: value.extent().into(),
317 kind: value.kind().into(),
318 }
319 }
320}
321
322impl From<&HostNodeParams> for driver::CUDA_HOST_NODE_PARAMS {
323 fn from(value: &HostNodeParams) -> Self {
324 Self {
325 fn_: value.function().as_raw(),
326 userData: value.user_data().cast(),
327 }
328 }
329}