Skip to main content

singe_cuda/graph/raw.rs

1use std::ptr;
2
3use singe_cuda_sys::{driver, runtime};
4
5use crate::{
6    graph::{Extent, Position},
7    memory::{ArrayHandle, MemoryCopyKind},
8    types::HostFunction,
9};
10
/// A raw pointer together with the row geometry of a pitched allocation.
///
/// This mirrors the CUDA runtime's `cudaPitchedPtr` and converts into it
/// field-for-field. The pointer itself is private, so it can only be set
/// through the `unsafe` [`PitchedPtr::new`] constructor.
///
/// Because `pitch`, `x_size`, and `y_size` are public, safe code can change
/// them after construction. Code that does so takes over the obligation from
/// [`PitchedPtr::new`]'s `# Safety` section: the pointer must still be valid
/// for the region described by the updated geometry by the time CUDA uses
/// the value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PitchedPtr {
    /// Base address of the allocation (device or mapped host memory).
    ptr: *mut (),
    /// Row pitch, forwarded as `cudaPitchedPtr::pitch`.
    pub pitch: usize,
    /// Logical row width, forwarded as `cudaPitchedPtr::xsize`.
    pub x_size: usize,
    /// Logical number of rows, forwarded as `cudaPitchedPtr::ysize`.
    pub y_size: usize,
}

impl PitchedPtr {
    /// Creates pitched pointer parameters from a raw device or mapped host pointer.
    ///
    /// # Safety
    ///
    /// `ptr` must be valid for every row described by `pitch`, `x_size`, and
    /// `y_size` when CUDA evaluates the graph node using this value. The same
    /// requirement applies to the geometry after any later modification of
    /// the public `pitch`, `x_size`, or `y_size` fields.
    pub const unsafe fn new(ptr: *mut (), pitch: usize, x_size: usize, y_size: usize) -> Self {
        Self {
            ptr,
            pitch,
            x_size,
            y_size,
        }
    }

    /// Returns the base pointer supplied to [`PitchedPtr::new`].
    pub const fn ptr(self) -> *mut () {
        self.ptr
    }
}
38
39#[derive(Debug, Clone, Copy)]
40pub struct MemoryCopy3DNodeParams {
41    src_array: Option<ArrayHandle>,
42    src_pos: Position,
43    src_ptr: PitchedPtr,
44    dst_array: Option<ArrayHandle>,
45    dst_pos: Position,
46    dst_ptr: PitchedPtr,
47    extent: Extent,
48    kind: MemoryCopyKind,
49}
50
51#[derive(Debug, Clone, Copy)]
52pub struct MemoryCopyToSymbolNodeParams {
53    symbol: *const (),
54    src: *const (),
55    count: usize,
56    offset: usize,
57    kind: MemoryCopyKind,
58}
59
60#[derive(Debug, Clone, Copy)]
61pub struct MemoryCopyFromSymbolNodeParams {
62    dst: *mut (),
63    symbol: *const (),
64    count: usize,
65    offset: usize,
66    kind: MemoryCopyKind,
67}
68
69#[derive(Debug, Clone, Copy)]
70pub struct HostNodeParams {
71    func: HostFunction,
72    user_data: *mut (),
73}
74
75impl HostNodeParams {
76    /// Creates host callback node parameters from a raw user-data pointer.
77    ///
78    /// # Safety
79    ///
80    /// `user_data` must remain valid for `func` according to CUDA host-node callback rules until no graph execution can invoke the callback.
81    pub const unsafe fn new(func: HostFunction, user_data: *mut ()) -> Self {
82        Self { func, user_data }
83    }
84
85    pub const fn function(self) -> HostFunction {
86        self.func
87    }
88
89    pub const fn user_data(self) -> *mut () {
90        self.user_data
91    }
92}
93
94#[derive(Debug, Clone, Copy)]
95pub struct MemoryCopy1DNodeParams {
96    dst: *mut (),
97    src: *const (),
98    count: usize,
99    kind: MemoryCopyKind,
100}
101
102impl MemoryCopyToSymbolNodeParams {
103    /// Creates symbol-copy node parameters from a raw source pointer.
104    ///
105    /// # Safety
106    ///
107    /// `symbol` must identify a valid CUDA symbol and `src` must be valid
108    /// for reads of `count` bytes according to `kind` when CUDA evaluates the graph node using this value.
109    pub const unsafe fn new(
110        symbol: *const (),
111        src: *const (),
112        count: usize,
113        offset: usize,
114        kind: MemoryCopyKind,
115    ) -> Self {
116        Self {
117            symbol,
118            src,
119            count,
120            offset,
121            kind,
122        }
123    }
124
125    pub const fn symbol(self) -> *const () {
126        self.symbol
127    }
128
129    pub const fn src(self) -> *const () {
130        self.src
131    }
132
133    pub const fn count(self) -> usize {
134        self.count
135    }
136
137    pub const fn offset(self) -> usize {
138        self.offset
139    }
140
141    pub const fn kind(self) -> MemoryCopyKind {
142        self.kind
143    }
144}
145
146impl MemoryCopyFromSymbolNodeParams {
147    /// Creates symbol-copy node parameters from a raw destination pointer.
148    ///
149    /// # Safety
150    ///
151    /// `symbol` must identify a valid CUDA symbol and `dst` must be valid
152    /// for writes of `count` bytes according to `kind` when CUDA evaluates
153    /// the graph node using this value.
154    pub const unsafe fn new(
155        dst: *mut (),
156        symbol: *const (),
157        count: usize,
158        offset: usize,
159        kind: MemoryCopyKind,
160    ) -> Self {
161        Self {
162            dst,
163            symbol,
164            count,
165            offset,
166            kind,
167        }
168    }
169
170    pub const fn dst(self) -> *mut () {
171        self.dst
172    }
173
174    pub const fn symbol(self) -> *const () {
175        self.symbol
176    }
177
178    pub const fn count(self) -> usize {
179        self.count
180    }
181
182    pub const fn offset(self) -> usize {
183        self.offset
184    }
185
186    pub const fn kind(self) -> MemoryCopyKind {
187        self.kind
188    }
189}
190
191impl MemoryCopy1DNodeParams {
192    /// Creates one-dimensional memcpy node parameters from raw pointers.
193    ///
194    /// # Safety
195    ///
196    /// `dst` and `src` must be valid for `count` bytes according to `kind` when CUDA evaluates the graph node using this value.
197    pub const unsafe fn new(
198        dst: *mut (),
199        src: *const (),
200        count: usize,
201        kind: MemoryCopyKind,
202    ) -> Self {
203        Self {
204            dst,
205            src,
206            count,
207            kind,
208        }
209    }
210
211    pub const fn dst(self) -> *mut () {
212        self.dst
213    }
214
215    pub const fn src(self) -> *const () {
216        self.src
217    }
218
219    pub const fn count(self) -> usize {
220        self.count
221    }
222
223    pub const fn kind(self) -> MemoryCopyKind {
224        self.kind
225    }
226}
227
228impl MemoryCopy3DNodeParams {
229    /// Creates three-dimensional memcpy node parameters from raw array and
230    /// pitched pointer operands.
231    ///
232    /// # Safety
233    ///
234    /// Every array and pointer operand selected by `kind` must remain
235    /// valid for the region described by the positions, pitched pointers,
236    /// and extent whenever CUDA evaluates a graph node using this value.
237    pub const unsafe fn new(
238        src_array: Option<ArrayHandle>,
239        src_pos: Position,
240        src_ptr: PitchedPtr,
241        dst_array: Option<ArrayHandle>,
242        dst_pos: Position,
243        dst_ptr: PitchedPtr,
244        extent: Extent,
245        kind: MemoryCopyKind,
246    ) -> Self {
247        Self {
248            src_array,
249            src_pos,
250            src_ptr,
251            dst_array,
252            dst_pos,
253            dst_ptr,
254            extent,
255            kind,
256        }
257    }
258
259    pub const fn src_array(self) -> Option<ArrayHandle> {
260        self.src_array
261    }
262
263    pub const fn src_pos(self) -> Position {
264        self.src_pos
265    }
266
267    pub const fn src_ptr(self) -> PitchedPtr {
268        self.src_ptr
269    }
270
271    pub const fn dst_array(self) -> Option<ArrayHandle> {
272        self.dst_array
273    }
274
275    pub const fn dst_pos(self) -> Position {
276        self.dst_pos
277    }
278
279    pub const fn dst_ptr(self) -> PitchedPtr {
280        self.dst_ptr
281    }
282
283    pub const fn extent(self) -> Extent {
284        self.extent
285    }
286
287    pub const fn kind(self) -> MemoryCopyKind {
288        self.kind
289    }
290}
291
292impl From<PitchedPtr> for runtime::cudaPitchedPtr {
293    fn from(value: PitchedPtr) -> Self {
294        Self {
295            ptr: value.ptr().cast(),
296            pitch: value.pitch as _,
297            xsize: value.x_size as _,
298            ysize: value.y_size as _,
299        }
300    }
301}
302
303impl From<&MemoryCopy3DNodeParams> for runtime::cudaMemcpy3DParms {
304    fn from(value: &MemoryCopy3DNodeParams) -> Self {
305        Self {
306            srcArray: value
307                .src_array()
308                .map_or(ptr::null_mut(), ArrayHandle::as_raw),
309            srcPos: value.src_pos().into(),
310            srcPtr: value.src_ptr().into(),
311            dstArray: value
312                .dst_array()
313                .map_or(ptr::null_mut(), ArrayHandle::as_raw),
314            dstPos: value.dst_pos().into(),
315            dstPtr: value.dst_ptr().into(),
316            extent: value.extent().into(),
317            kind: value.kind().into(),
318        }
319    }
320}
321
322impl From<&HostNodeParams> for driver::CUDA_HOST_NODE_PARAMS {
323    fn from(value: &HostNodeParams) -> Self {
324        Self {
325            fn_: value.function().as_raw(),
326            userData: value.user_data().cast(),
327        }
328    }
329}