// cudarc/driver/safe/unified_memory.rs

1use core::marker::PhantomData;
2use std::sync::Arc;
3
4use crate::driver::{result, sys};
5
6use super::{
7    CudaContext, CudaEvent, CudaStream, DevicePtr, DevicePtrMut, DeviceRepr, DeviceSlice,
8    DriverError, HostSlice, LaunchArgs, PushKernelArg, ValidAsZeroBits,
9};
10
/// Unified memory allocated with [CudaContext::alloc_unified()] (via [cuMemAllocManaged](https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1gb347ded34dc326af404aa02af5388a32)).
///
/// This is memory that can be accessed by host side (rust code) AND device side kernels. For host side access you can read/write using
/// [UnifiedSlice::as_slice()]/[UnifiedSlice::as_mut_slice()]. You can read/write host side no matter what attach mode you set
/// (via [UnifiedSlice::attach()], or the value you use to create the slice in [CudaContext::alloc_unified()]).
///
/// This struct also implements [HostSlice] and [DeviceSlice], meaning you can use it with various [CudaStream] related calls for doing memcpy/memset operations.
///
/// Finally, it implements [PushKernelArg], so you can pass it as a device pointer to a kernel.
///
/// For any device access, the restrictions are a bit more complicated depending on the attach mode:
/// 1. [sys::CUmemAttach_flags::CU_MEM_ATTACH_HOST] - a device can ONLY access if [sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MANAGED_MEMORY] is non-zero.
/// 2. [sys::CUmemAttach_flags::CU_MEM_ATTACH_GLOBAL] - any device/stream can access it.
/// 3. [sys::CUmemAttach_flags::CU_MEM_ATTACH_SINGLE] - only the stream you attach it to can access it. Additionally, accessing on the CPU synchronizes the associated stream.
///
/// See [cuda docs for Unified Addressing/Unified Memory](https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__UNIFIED.html#group__CUDA__UNIFIED)
///
/// # Thread safety
///
/// This is thread safe.
#[derive(Debug)]
pub struct UnifiedSlice<T> {
    /// Raw managed pointer returned by `cuMemAllocManaged`; usable on host and device.
    pub(crate) cu_device_ptr: sys::CUdeviceptr,
    /// Number of `T` elements (NOT bytes; see [UnifiedSlice::num_bytes()]).
    pub(crate) len: usize,
    /// Stream this allocation is currently associated with (changed by [UnifiedSlice::attach()]).
    pub(crate) stream: Arc<CudaStream>,
    /// Event used to order host access and new device work after prior uses of this slice.
    pub(crate) event: CudaEvent,
    /// Current attach mode (GLOBAL/HOST/SINGLE); governs which streams may access the memory.
    pub(crate) attach_mode: sys::CUmemAttach_flags,
    /// Cached CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS for the allocating device,
    /// so access checks don't re-query the driver every time.
    pub(crate) concurrent_managed_access: bool,
    /// `*const T` marker ties the element type to the allocation without storing a `T`.
    pub(crate) marker: PhantomData<*const T>,
}
41
// SAFETY: the raw pointer refers to CUDA managed memory, which is accessible from any
// host thread; cross-thread ordering is mediated through the stored `stream`/`event`.
// The `PhantomData<*const T>` marker would otherwise make this type !Send/!Sync, so we
// opt back in manually here.
// NOTE(review): these impls place no `Send`/`Sync` bound on `T`, even though host code
// can read `T` values out of the slice from any thread — confirm this is intended.
unsafe impl<T> Send for UnifiedSlice<T> {}
unsafe impl<T> Sync for UnifiedSlice<T> {}
44
45impl<T> Drop for UnifiedSlice<T> {
46    fn drop(&mut self) {
47        self.stream.ctx.record_err(self.event.synchronize());
48        self.stream
49            .ctx
50            .record_err(unsafe { result::memory_free(self.cu_device_ptr) });
51    }
52}
53
54impl CudaContext {
55    /// Allocates managed memory using [cuMemAllocManaged](https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1gb347ded34dc326af404aa02af5388a32).
56    ///
57    /// If `attach_global` is true, then allocates the memory with flag [sys::CUmemAttach_flags::CU_MEM_ATTACH_GLOBAL],
58    /// otherwise uses flag [sys::CUmemAttach_flags::CU_MEM_ATTACH_HOST].
59    ///
60    /// Note that only these two flags are valid during allocation, you can change the
61    /// attach mode later via [UnifiedSlice::attach()]
62    ///
63    /// If the device does not support managed memory ([sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MANAGED_MEMORY] is 0),
64    /// then this method will return Err with [sys::cudaError_enum::CUDA_ERROR_NOT_PERMITTED].
65    ///
66    /// # Safety
67    ///
68    /// This is unsafe because this method has no restrictions that `T` is valid for any bit pattern.
69    pub unsafe fn alloc_unified<T: DeviceRepr>(
70        self: &Arc<Self>,
71        len: usize,
72        attach_global: bool,
73    ) -> Result<UnifiedSlice<T>, DriverError> {
74        // NOTE: The pointer is valid on the CPU and on all GPUs in the system that support managed memory.
75        if self.attribute(sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MANAGED_MEMORY)? == 0 {
76            return Err(DriverError(sys::cudaError_enum::CUDA_ERROR_NOT_PERMITTED));
77        }
78
79        let attach_mode = if attach_global {
80            sys::CUmemAttach_flags::CU_MEM_ATTACH_GLOBAL
81        } else {
82            sys::CUmemAttach_flags::CU_MEM_ATTACH_HOST
83        };
84
85        let cu_device_ptr = result::malloc_managed(len * std::mem::size_of::<T>(), attach_mode)?;
86        let concurrent_managed_access = self
87            .attribute(sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS)?
88            != 0;
89
90        let stream = self.default_stream();
91        let event = self.new_event(Some(sys::CUevent_flags::CU_EVENT_BLOCKING_SYNC))?;
92
93        Ok(UnifiedSlice {
94            cu_device_ptr,
95            len,
96            stream,
97            event,
98            attach_mode,
99            concurrent_managed_access,
100            marker: PhantomData,
101        })
102    }
103}
104
105impl<T> UnifiedSlice<T> {
106    pub fn len(&self) -> usize {
107        self.len
108    }
109
110    pub fn is_empty(&self) -> bool {
111        self.len == 0
112    }
113
114    pub fn attach_mode(&self) -> sys::CUmemAttach_flags {
115        self.attach_mode
116    }
117
118    pub fn num_bytes(&self) -> usize {
119        self.len * std::mem::size_of::<T>()
120    }
121
122    /// See [cuStreamAttachMemAsync cuda docs](https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__STREAM.html#group__CUDA__STREAM_1g6e468d680e263e7eba02a56643c50533)
123    ///
124    /// NOTE: if stream is the null stream, then cuda will return an error.
125    pub fn attach(
126        &mut self,
127        stream: &Arc<CudaStream>,
128        flags: sys::CUmemAttach_flags,
129    ) -> Result<(), DriverError> {
130        self.event.synchronize()?;
131        self.stream = stream.clone();
132        self.attach_mode = flags;
133        unsafe {
134            result::stream::attach_mem_async(
135                self.stream.cu_stream,
136                self.cu_device_ptr,
137                self.num_bytes(),
138                self.attach_mode,
139            )
140        }
141    }
142
143    /// See [cuMemPrefetchAsync_v2 cuda docs](https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__UNIFIED.html#group__CUDA__UNIFIED_1gaf4f188a71891ad6a71fdd2850c8d638)
144    #[cfg(not(any(
145        feature = "cuda-11040",
146        feature = "cuda-11050",
147        feature = "cuda-11060",
148        feature = "cuda-11070",
149        feature = "cuda-11080",
150        feature = "cuda-12000",
151        feature = "cuda-12010"
152    )))]
153    pub fn prefetch(&self) -> Result<(), DriverError> {
154        let location = match self.attach_mode {
155            sys::CUmemAttach_flags_enum::CU_MEM_ATTACH_GLOBAL
156            | sys::CUmemAttach_flags_enum::CU_MEM_ATTACH_SINGLE => {
157                // > Specifying CU_MEM_LOCATION_TYPE_DEVICE for CUmemLocation::type will prefetch memory to GPU specified by device ordinal CUmemLocation::id which must have non-zero value for the device attribute CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS. Additionally, hStream must be associated with a device that has a non-zero value for the device attribute CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS.
158                if !self.concurrent_managed_access {
159                    return Err(DriverError(sys::cudaError_enum::CUDA_ERROR_NOT_PERMITTED));
160                }
161                sys::CUmemLocation {
162                    type_: sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_DEVICE,
163                    id: self.stream.ctx.ordinal as i32,
164                }
165            }
166            sys::CUmemAttach_flags_enum::CU_MEM_ATTACH_HOST => {
167                // > Specifying CU_MEM_LOCATION_TYPE_HOST as CUmemLocation::type will prefetch data to host memory. Applications can request prefetching memory to a specific host NUMA node by specifying CU_MEM_LOCATION_TYPE_HOST_NUMA for CUmemLocation::type and a valid host NUMA node id in CUmemLocation::id Users can also request prefetching memory to the host NUMA node closest to the current thread's CPU by specifying CU_MEM_LOCATION_TYPE_HOST_NUMA_CURRENT for CUmemLocation::type.
168                sys::CUmemLocation {
169                    type_: sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_HOST_NUMA_CURRENT,
170                    id: 0, // NOTE: ignored
171                }
172            }
173        };
174        unsafe {
175            result::mem_prefetch_async(
176                self.cu_device_ptr,
177                self.len * std::mem::size_of::<T>(),
178                location,
179                self.stream.cu_stream,
180            )
181        }
182    }
183
184    pub fn check_host_access(&self) -> Result<(), DriverError> {
185        match self.attach_mode {
186            sys::CUmemAttach_flags_enum::CU_MEM_ATTACH_GLOBAL => {
187                // NOTE: can't find info about this case in the docs anywhere. It is easy to assume
188                // that since SINGLE needs the stream synchronized to access, than GLOBAL might need the whole context
189                // synchronized. But unable to confirm this assumption
190            }
191            sys::CUmemAttach_flags_enum::CU_MEM_ATTACH_HOST => {
192                // NOTE: Most of the docs talk about device access when HOST is specified, but unable to find
193                // anything on constraints for CPU access.
194            }
195            sys::CUmemAttach_flags_enum::CU_MEM_ATTACH_SINGLE => {
196                // > When memory is associated with a single stream, the Unified Memory system will allow CPU access to this memory region so long as all operations in hStream have completed, regardless of whether other streams are active. In effect, this constrains exclusive ownership of the managed memory region by an active GPU to per-stream activity instead of whole-GPU activity.
197                self.stream.synchronize()?;
198            }
199        };
200        Ok(())
201    }
202
203    pub fn check_device_access(&self, stream: &CudaStream) -> Result<(), DriverError> {
204        // > Accessing memory on the device from streams that are not associated with it will produce undefined results. No error checking is performed by the Unified Memory system to ensure that kernels launched into other streams do not access this region.
205        match self.attach_mode {
206            sys::CUmemAttach_flags_enum::CU_MEM_ATTACH_GLOBAL => {
207                // NOTE: no checks needed here, because any context/stream can access when GLOBAL mode is used.
208            }
209            sys::CUmemAttach_flags_enum::CU_MEM_ATTACH_HOST => {
210                // > If CU_MEM_ATTACH_HOST is specified, then the allocation should not be accessed from devices that have a zero value for the device attribute CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS;
211                // > If the CU_MEM_ATTACH_HOST flag is specified, the program makes a guarantee that it won't access the memory on the device from any stream on a device that has a zero value for the device attribute CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS
212                let concurrent_managed_access = if self.stream.context() != stream.context() {
213                    // if we are going to access in a different context, we need to check for concurrent managed access
214                    stream.context().attribute(
215                        sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS,
216                    )? != 0
217                } else {
218                    // otherwise we can use the cached value for the attribute
219                    self.concurrent_managed_access
220                };
221                if !concurrent_managed_access {
222                    return Err(DriverError(sys::cudaError_enum::CUDA_ERROR_NOT_PERMITTED));
223                }
224            }
225            sys::CUmemAttach_flags_enum::CU_MEM_ATTACH_SINGLE => {
226                // > If the CU_MEM_ATTACH_SINGLE flag is specified and hStream is associated with a device that has a zero value for the device attribute CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS, the program makes a guarantee that it will only access the memory on the device from hStream
227                // > Accessing memory on the device from streams that are not associated with it will produce undefined results. No error checking is performed by the Unified Memory system to ensure that kernels launched into other streams do not access this region.
228                if self.stream.as_ref() != stream {
229                    return Err(DriverError(sys::cudaError_enum::CUDA_ERROR_NOT_PERMITTED));
230                }
231            }
232        };
233        Ok(())
234    }
235}
236
/// Lets [UnifiedSlice] be used wherever a device slice is expected (memcpy/memset helpers).
impl<T> DeviceSlice<T> for UnifiedSlice<T> {
    /// Number of `T` elements in the allocation.
    fn len(&self) -> usize {
        self.len
    }
    /// The stream this allocation is currently associated with.
    fn stream(&self) -> &Arc<CudaStream> {
        &self.stream
    }
}
245
impl<T> DevicePtr<T> for UnifiedSlice<T> {
    /// Returns the raw device pointer for use on `stream`, ordering `stream` after
    /// prior recorded uses of this slice.
    fn device_ptr<'a>(
        &'a self,
        stream: &'a CudaStream,
    ) -> (sys::CUdeviceptr, super::SyncOnDrop<'a>) {
        // Surface illegal access for the current attach mode as a recorded error.
        stream.ctx.record_err(self.check_device_access(stream));
        // Make `stream` wait for any previously recorded work on this slice's event.
        stream.ctx.record_err(stream.wait(&self.event));
        (
            self.cu_device_ptr,
            // Re-record the event on `stream` when the caller's use ends, so later
            // host/device access is ordered after this use.
            super::SyncOnDrop::Record(Some((&self.event, stream))),
        )
    }
}
259
impl<T> DevicePtrMut<T> for UnifiedSlice<T> {
    /// Mutable counterpart of [DevicePtr::device_ptr]: returns the raw device pointer
    /// for use on `stream`, ordering `stream` after prior recorded uses of this slice.
    fn device_ptr_mut<'a>(
        &'a mut self,
        stream: &'a CudaStream,
    ) -> (sys::CUdeviceptr, super::SyncOnDrop<'a>) {
        // Surface illegal access for the current attach mode as a recorded error.
        stream.ctx.record_err(self.check_device_access(stream));
        // Make `stream` wait for any previously recorded work on this slice's event.
        stream.ctx.record_err(stream.wait(&self.event));
        (
            self.cu_device_ptr,
            // Re-record the event on `stream` when the caller's use ends.
            super::SyncOnDrop::Record(Some((&self.event, stream))),
        )
    }
}
273
impl<T: ValidAsZeroBits> UnifiedSlice<T> {
    /// Waits for any scheduled work to complete and then returns a reference
    /// to the host side data.
    ///
    /// # Errors
    /// Fails if host access is not permitted under the current attach mode
    /// (see [UnifiedSlice::check_host_access()]) or if synchronization fails.
    pub fn as_slice(&self) -> Result<&[T], DriverError> {
        self.check_host_access()?;
        self.event.synchronize()?;
        // SAFETY: the pointer refers to a host-accessible managed allocation of `len`
        // elements, and `T: ValidAsZeroBits` means device-written bytes form valid `T`s.
        // Assumes no device work writes this region after the sync above.
        Ok(unsafe { std::slice::from_raw_parts(self.cu_device_ptr as *const T, self.len) })
    }

    /// Waits for any scheduled work to complete and then returns a mutable reference
    /// to the host side data.
    ///
    /// # Errors
    /// Fails if host access is not permitted under the current attach mode
    /// (see [UnifiedSlice::check_host_access()]) or if synchronization fails.
    pub fn as_mut_slice(&mut self) -> Result<&mut [T], DriverError> {
        self.check_host_access()?;
        self.event.synchronize()?;
        // SAFETY: same as `as_slice`; `&mut self` guarantees exclusive host-side access.
        Ok(unsafe { std::slice::from_raw_parts_mut(self.cu_device_ptr as *mut T, self.len) })
    }
}
291
/// Lets [UnifiedSlice] act as host memory for stream memcpy operations.
impl<T> HostSlice<T> for UnifiedSlice<T> {
    fn len(&self) -> usize {
        self.len
    }
    /// Returns a host view of the managed memory for use by `stream`, after making
    /// `stream` wait on previously recorded work for this slice.
    unsafe fn stream_synced_slice<'a>(
        &'a self,
        stream: &'a CudaStream,
    ) -> (&'a [T], super::SyncOnDrop<'a>) {
        // Surface illegal access for the current attach mode as a recorded error.
        stream.ctx.record_err(self.check_device_access(stream));
        // Order `stream` after prior recorded uses of this slice.
        stream.ctx.record_err(stream.wait(&self.event));
        (
            std::slice::from_raw_parts(self.cu_device_ptr as *const T, self.len),
            // Re-record the event on `stream` when this use completes.
            super::SyncOnDrop::Record(Some((&self.event, stream))),
        )
    }

    /// Mutable counterpart of [HostSlice::stream_synced_slice].
    unsafe fn stream_synced_mut_slice<'a>(
        &'a mut self,
        stream: &'a CudaStream,
    ) -> (&'a mut [T], super::SyncOnDrop<'a>) {
        // Surface illegal access for the current attach mode as a recorded error.
        stream.ctx.record_err(self.check_device_access(stream));
        // Order `stream` after prior recorded uses of this slice.
        stream.ctx.record_err(stream.wait(&self.event));
        (
            std::slice::from_raw_parts_mut(self.cu_device_ptr as *mut T, self.len),
            // Re-record the event on `stream` when this use completes.
            super::SyncOnDrop::Record(Some((&self.event, stream))),
        )
    }
}
320
/// Allows passing `&UnifiedSlice<T>` directly as a kernel argument (as a device pointer).
unsafe impl<'a, 'b: 'a, T> PushKernelArg<&'b UnifiedSlice<T>> for LaunchArgs<'a> {
    #[inline(always)]
    fn arg(&mut self, arg: &'b UnifiedSlice<T>) -> &mut Self {
        // Surface illegal access for the slice's attach mode as a recorded error.
        self.stream
            .ctx
            .record_err(arg.check_device_access(self.stream));
        // The launch waits on prior uses of this slice and re-records its event
        // afterwards, ordering later host/device access after the kernel.
        self.waits.push(&arg.event);
        self.records.push(&arg.event);
        // Kernels receive the raw device pointer; `&arg.cu_device_ptr` outlives the
        // launch because `'b: 'a`.
        self.args
            .push((&arg.cu_device_ptr) as *const sys::CUdeviceptr as _);
        self
    }
}
334
/// Allows passing `&mut UnifiedSlice<T>` directly as a kernel argument (as a device pointer).
unsafe impl<'a, 'b: 'a, T> PushKernelArg<&'b mut UnifiedSlice<T>> for LaunchArgs<'a> {
    #[inline(always)]
    fn arg(&mut self, arg: &'b mut UnifiedSlice<T>) -> &mut Self {
        // Surface illegal access for the slice's attach mode as a recorded error.
        self.stream
            .ctx
            .record_err(arg.check_device_access(self.stream));
        // The launch waits on prior uses of this slice and re-records its event
        // afterwards, ordering later host/device access after the kernel.
        self.waits.push(&arg.event);
        self.records.push(&arg.event);
        // Kernels receive the raw device pointer; `&arg.cu_device_ptr` outlives the
        // launch because `'b: 'a`.
        self.args
            .push((&arg.cu_device_ptr) as *const sys::CUdeviceptr as _);
        self
    }
}
348
#[cfg(test)]
mod tests {
    use crate::driver::{LaunchConfig, PushKernelArg};

    use super::*;

    /// Kernel shared by all tests: asserts `buf[i] == i` for the first 100 lanes.
    const ASSERT_SEQ_KERNEL: &str = "
extern \"C\" __global__ void kernel(float *buf) {
    if (threadIdx.x < 100) {
        assert(buf[threadIdx.x] == static_cast<float>(threadIdx.x));
    }
}";

    /// Writes `0.0, 1.0, ..., len-1` into `a` from the host side.
    fn fill_seq(a: &mut UnifiedSlice<f32>) -> Result<(), DriverError> {
        for (i, x) in a.as_mut_slice()?.iter_mut().enumerate() {
            *x = i as f32;
        }
        Ok(())
    }

    /// Asserts every host-visible element of `a` equals `expected(index)`.
    fn check_host(
        a: &UnifiedSlice<f32>,
        expected: impl Fn(usize) -> f32,
    ) -> Result<(), DriverError> {
        for (i, x) in a.as_slice()?.iter().enumerate() {
            assert_eq!(*x, expected(i));
        }
        Ok(())
    }

    /// GLOBAL attach mode: host read/write works, and any stream may access the memory.
    #[test]
    fn test_unified_memory_global() -> Result<(), DriverError> {
        let ctx = CudaContext::new(0)?;

        let mut a = unsafe { ctx.alloc_unified::<f32>(100, true) }?;
        fill_seq(&mut a)?;
        check_host(&a, |i| i as f32)?;

        let ptx = crate::nvrtc::compile_ptx(ASSERT_SEQ_KERNEL).unwrap();
        let module = ctx.load_module(ptx)?;
        let f = module.load_function("kernel")?;

        let stream1 = ctx.default_stream();
        unsafe {
            stream1
                .launch_builder(&f)
                .arg(&mut a)
                .launch(LaunchConfig::for_num_elems(100))
        }?;
        stream1.synchronize()?;

        // GLOBAL: a second, non-default stream may also access the allocation.
        let stream2 = ctx.new_stream()?;
        unsafe {
            stream2
                .launch_builder(&f)
                .arg(&mut a)
                .launch(LaunchConfig::for_num_elems(100))
        }?;
        stream2.synchronize()?;

        check_host(&a, |i| i as f32)?;

        // check usage as device ptr
        let vs = stream1.memcpy_dtov(&a)?;
        for (i, v) in vs.iter().enumerate() {
            assert_eq!(*v, i as f32);
        }

        // check usage as host ptr
        let b = stream1.memcpy_stod(&a)?;
        let vs = stream1.memcpy_dtov(&b)?;
        for (i, v) in vs.iter().enumerate() {
            assert_eq!(*v, i as f32);
        }

        // check writing on device
        stream1.memset_zeros(&mut a)?;
        check_host(&a, |_| 0.0)?;

        Ok(())
    }

    /// HOST attach mode: host read/write works, and (on devices with concurrent
    /// managed access) streams may still access the memory.
    #[test]
    fn test_unified_memory_host() -> Result<(), DriverError> {
        let ctx = CudaContext::new(0)?;

        let mut a = unsafe { ctx.alloc_unified::<f32>(100, false) }?;
        fill_seq(&mut a)?;
        check_host(&a, |i| i as f32)?;

        let ptx = crate::nvrtc::compile_ptx(ASSERT_SEQ_KERNEL).unwrap();
        let module = ctx.load_module(ptx)?;
        let f = module.load_function("kernel")?;

        let stream1 = ctx.default_stream();
        unsafe {
            stream1
                .launch_builder(&f)
                .arg(&mut a)
                .launch(LaunchConfig::for_num_elems(100))
        }?;
        stream1.synchronize()?;

        let stream2 = ctx.new_stream()?;
        unsafe {
            stream2
                .launch_builder(&f)
                .arg(&mut a)
                .launch(LaunchConfig::for_num_elems(100))
        }?;
        stream2.synchronize()?;

        check_host(&a, |i| i as f32)?;

        // check usage as device ptr
        let vs = stream1.memcpy_dtov(&a)?;
        for (i, v) in vs.iter().enumerate() {
            assert_eq!(*v, i as f32);
        }

        // check usage as host ptr
        let b = stream1.memcpy_stod(&a)?;
        let vs = stream1.memcpy_dtov(&b)?;
        for (i, v) in vs.iter().enumerate() {
            assert_eq!(*v, i as f32);
        }

        // check writing on device
        stream1.memset_zeros(&mut a)?;
        check_host(&a, |_| 0.0)?;

        Ok(())
    }

    /// SINGLE attach mode: only the attached stream may access the memory; any other
    /// stream is rejected by [UnifiedSlice::check_device_access()].
    #[test]
    fn test_unified_memory_single_stream() -> Result<(), DriverError> {
        let ctx = CudaContext::new(0)?;

        let mut a = unsafe { ctx.alloc_unified::<f32>(100, true) }?;
        fill_seq(&mut a)?;
        check_host(&a, |i| i as f32)?;

        let ptx = crate::nvrtc::compile_ptx(ASSERT_SEQ_KERNEL).unwrap();
        let module = ctx.load_module(ptx)?;
        let f = module.load_function("kernel")?;

        let stream2 = ctx.new_stream()?;
        a.attach(&stream2, sys::CUmemAttach_flags::CU_MEM_ATTACH_SINGLE)?;
        unsafe {
            stream2
                .launch_builder(&f)
                .arg(&mut a)
                .launch(LaunchConfig::for_num_elems(100))
        }?;
        stream2.synchronize()?;

        // Launching from a different stream must fail in SINGLE mode.
        let stream1 = ctx.default_stream();
        unsafe {
            stream1
                .launch_builder(&f)
                .arg(&mut a)
                .launch(LaunchConfig::for_num_elems(100))
        }
        .expect_err("Other stream access should've failed");

        check_host(&a, |i| i as f32)?;

        // check usage as device ptr
        let vs = stream2.memcpy_dtov(&a)?;
        for (i, v) in vs.iter().enumerate() {
            assert_eq!(*v, i as f32);
        }

        // check usage as host ptr
        let b = stream2.memcpy_stod(&a)?;
        let vs = stream2.memcpy_dtov(&b)?;
        for (i, v) in vs.iter().enumerate() {
            assert_eq!(*v, i as f32);
        }

        // check writing on device
        stream2.memset_zeros(&mut a)?;
        check_host(&a, |_| 0.0)?;

        Ok(())
    }
}
597}