1use core::marker::PhantomData;
2use std::sync::Arc;
3
4use crate::driver::{result, sys};
5
6use super::{
7 CudaContext, CudaEvent, CudaStream, DevicePtr, DevicePtrMut, DeviceRepr, DeviceSlice,
8 DriverError, HostSlice, LaunchArgs, PushKernelArg, ValidAsZeroBits,
9};
10
/// A slice of CUDA unified (managed) memory, accessible from both host and
/// device, with stream/event bookkeeping to order accesses.
#[derive(Debug)]
pub struct UnifiedSlice<T> {
    // Device pointer returned by the managed allocation.
    pub(crate) cu_device_ptr: sys::CUdeviceptr,
    // Number of `T` elements in the allocation.
    pub(crate) len: usize,
    // Stream this allocation is currently associated with (see `attach`).
    pub(crate) stream: Arc<CudaStream>,
    // Event used to order host/device accesses to this memory.
    pub(crate) event: CudaEvent,
    // How the allocation is attached (GLOBAL / HOST / SINGLE).
    pub(crate) attach_mode: sys::CUmemAttach_flags,
    // Cached CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS for the device.
    pub(crate) concurrent_managed_access: bool,
    // `*const T` makes this type !Send/!Sync by default; the manual unsafe
    // impls below opt back in.
    pub(crate) marker: PhantomData<*const T>,
}
41
// SAFETY(review): cross-thread use appears to be serialized through
// `check_host_access`/`check_device_access` and the internal `event`.
// NOTE(review): there is no `T: Send`/`T: Sync` bound here — presumably fine
// because `T` values only cross the FFI boundary as raw bytes; confirm.
unsafe impl<T> Send for UnifiedSlice<T> {}
unsafe impl<T> Sync for UnifiedSlice<T> {}
44
45impl<T> Drop for UnifiedSlice<T> {
46 fn drop(&mut self) {
47 self.stream.ctx.record_err(self.event.synchronize());
48 self.stream
49 .ctx
50 .record_err(unsafe { result::memory_free(self.cu_device_ptr) });
51 }
52}
53
54impl CudaContext {
55 pub unsafe fn alloc_unified<T: DeviceRepr>(
70 self: &Arc<Self>,
71 len: usize,
72 attach_global: bool,
73 ) -> Result<UnifiedSlice<T>, DriverError> {
74 if self.attribute(sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MANAGED_MEMORY)? == 0 {
76 return Err(DriverError(sys::cudaError_enum::CUDA_ERROR_NOT_PERMITTED));
77 }
78
79 let attach_mode = if attach_global {
80 sys::CUmemAttach_flags::CU_MEM_ATTACH_GLOBAL
81 } else {
82 sys::CUmemAttach_flags::CU_MEM_ATTACH_HOST
83 };
84
85 let cu_device_ptr = result::malloc_managed(len * std::mem::size_of::<T>(), attach_mode)?;
86 let concurrent_managed_access = self
87 .attribute(sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS)?
88 != 0;
89
90 let stream = self.default_stream();
91 let event = self.new_event(Some(sys::CUevent_flags::CU_EVENT_BLOCKING_SYNC))?;
92
93 Ok(UnifiedSlice {
94 cu_device_ptr,
95 len,
96 stream,
97 event,
98 attach_mode,
99 concurrent_managed_access,
100 marker: PhantomData,
101 })
102 }
103}
104
105impl<T> UnifiedSlice<T> {
106 pub fn len(&self) -> usize {
107 self.len
108 }
109
110 pub fn is_empty(&self) -> bool {
111 self.len == 0
112 }
113
114 pub fn attach_mode(&self) -> sys::CUmemAttach_flags {
115 self.attach_mode
116 }
117
118 pub fn num_bytes(&self) -> usize {
119 self.len * std::mem::size_of::<T>()
120 }
121
122 pub fn attach(
126 &mut self,
127 stream: &Arc<CudaStream>,
128 flags: sys::CUmemAttach_flags,
129 ) -> Result<(), DriverError> {
130 self.event.synchronize()?;
131 self.stream = stream.clone();
132 self.attach_mode = flags;
133 unsafe {
134 result::stream::attach_mem_async(
135 self.stream.cu_stream,
136 self.cu_device_ptr,
137 self.num_bytes(),
138 self.attach_mode,
139 )
140 }
141 }
142
143 #[cfg(not(any(
145 feature = "cuda-11040",
146 feature = "cuda-11050",
147 feature = "cuda-11060",
148 feature = "cuda-11070",
149 feature = "cuda-11080",
150 feature = "cuda-12000",
151 feature = "cuda-12010"
152 )))]
153 pub fn prefetch(&self) -> Result<(), DriverError> {
154 let location = match self.attach_mode {
155 sys::CUmemAttach_flags_enum::CU_MEM_ATTACH_GLOBAL
156 | sys::CUmemAttach_flags_enum::CU_MEM_ATTACH_SINGLE => {
157 if !self.concurrent_managed_access {
159 return Err(DriverError(sys::cudaError_enum::CUDA_ERROR_NOT_PERMITTED));
160 }
161 sys::CUmemLocation {
162 type_: sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_DEVICE,
163 id: self.stream.ctx.ordinal as i32,
164 }
165 }
166 sys::CUmemAttach_flags_enum::CU_MEM_ATTACH_HOST => {
167 sys::CUmemLocation {
169 type_: sys::CUmemLocationType::CU_MEM_LOCATION_TYPE_HOST_NUMA_CURRENT,
170 id: 0, }
172 }
173 };
174 unsafe {
175 result::mem_prefetch_async(
176 self.cu_device_ptr,
177 self.len * std::mem::size_of::<T>(),
178 location,
179 self.stream.cu_stream,
180 )
181 }
182 }
183
184 pub fn check_host_access(&self) -> Result<(), DriverError> {
185 match self.attach_mode {
186 sys::CUmemAttach_flags_enum::CU_MEM_ATTACH_GLOBAL => {
187 }
191 sys::CUmemAttach_flags_enum::CU_MEM_ATTACH_HOST => {
192 }
195 sys::CUmemAttach_flags_enum::CU_MEM_ATTACH_SINGLE => {
196 self.stream.synchronize()?;
198 }
199 };
200 Ok(())
201 }
202
203 pub fn check_device_access(&self, stream: &CudaStream) -> Result<(), DriverError> {
204 match self.attach_mode {
206 sys::CUmemAttach_flags_enum::CU_MEM_ATTACH_GLOBAL => {
207 }
209 sys::CUmemAttach_flags_enum::CU_MEM_ATTACH_HOST => {
210 let concurrent_managed_access = if self.stream.context() != stream.context() {
213 stream.context().attribute(
215 sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS,
216 )? != 0
217 } else {
218 self.concurrent_managed_access
220 };
221 if !concurrent_managed_access {
222 return Err(DriverError(sys::cudaError_enum::CUDA_ERROR_NOT_PERMITTED));
223 }
224 }
225 sys::CUmemAttach_flags_enum::CU_MEM_ATTACH_SINGLE => {
226 if self.stream.as_ref() != stream {
229 return Err(DriverError(sys::cudaError_enum::CUDA_ERROR_NOT_PERMITTED));
230 }
231 }
232 };
233 Ok(())
234 }
235}
236
impl<T> DeviceSlice<T> for UnifiedSlice<T> {
    /// Number of `T` elements.
    fn len(&self) -> usize {
        self.len
    }
    /// The stream this allocation is currently associated with.
    fn stream(&self) -> &Arc<CudaStream> {
        &self.stream
    }
}
245
impl<T> DevicePtr<T> for UnifiedSlice<T> {
    fn device_ptr<'a>(
        &'a self,
        stream: &'a CudaStream,
    ) -> (sys::CUdeviceptr, super::SyncOnDrop<'a>) {
        // Validate `stream` may access this attachment; the trait has no
        // Result, so failures are recorded on the context.
        stream.ctx.record_err(self.check_device_access(stream));
        // Order this use after previously recorded work on the slice's event.
        stream.ctx.record_err(stream.wait(&self.event));
        (
            self.cu_device_ptr,
            // When the guard drops, re-record the event on `stream` so later
            // users wait on the work just enqueued.
            super::SyncOnDrop::Record(Some((&self.event, stream))),
        )
    }
}
259
impl<T> DevicePtrMut<T> for UnifiedSlice<T> {
    fn device_ptr_mut<'a>(
        &'a mut self,
        stream: &'a CudaStream,
    ) -> (sys::CUdeviceptr, super::SyncOnDrop<'a>) {
        // Same protocol as `device_ptr`: check access, wait on prior work,
        // and re-record the event when the guard drops.
        stream.ctx.record_err(self.check_device_access(stream));
        stream.ctx.record_err(stream.wait(&self.event));
        (
            self.cu_device_ptr,
            super::SyncOnDrop::Record(Some((&self.event, stream))),
        )
    }
}
273
impl<T: ValidAsZeroBits> UnifiedSlice<T> {
    /// Borrows the allocation as a host slice, synchronizing pending device
    /// work first (attachment check + internal event).
    pub fn as_slice(&self) -> Result<&[T], DriverError> {
        self.check_host_access()?;
        self.event.synchronize()?;
        // SAFETY: managed allocations are host-accessible and hold `len`
        // elements. NOTE(review): `T: ValidAsZeroBits` presumably guards
        // reads of never-written memory — confirm allocation is zeroed.
        Ok(unsafe { std::slice::from_raw_parts(self.cu_device_ptr as *const T, self.len) })
    }

    /// Mutably borrows the allocation as a host slice, synchronizing pending
    /// device work first (attachment check + internal event).
    pub fn as_mut_slice(&mut self) -> Result<&mut [T], DriverError> {
        self.check_host_access()?;
        self.event.synchronize()?;
        // SAFETY: as above; `&mut self` guarantees exclusive host access.
        Ok(unsafe { std::slice::from_raw_parts_mut(self.cu_device_ptr as *mut T, self.len) })
    }
}
291
impl<T> HostSlice<T> for UnifiedSlice<T> {
    fn len(&self) -> usize {
        self.len
    }
    unsafe fn stream_synced_slice<'a>(
        &'a self,
        stream: &'a CudaStream,
    ) -> (&'a [T], super::SyncOnDrop<'a>) {
        // Check access and order after prior work, recording any errors on
        // the context (no Result in this trait method's signature).
        stream.ctx.record_err(self.check_device_access(stream));
        stream.ctx.record_err(stream.wait(&self.event));
        (
            // SAFETY (caller contract + above sync): managed memory is
            // host-accessible and holds `len` elements.
            std::slice::from_raw_parts(self.cu_device_ptr as *const T, self.len),
            // Re-record the event on `stream` when the guard drops.
            super::SyncOnDrop::Record(Some((&self.event, stream))),
        )
    }

    unsafe fn stream_synced_mut_slice<'a>(
        &'a mut self,
        stream: &'a CudaStream,
    ) -> (&'a mut [T], super::SyncOnDrop<'a>) {
        // Same protocol as `stream_synced_slice`, with exclusive access
        // guaranteed by `&mut self`.
        stream.ctx.record_err(self.check_device_access(stream));
        stream.ctx.record_err(stream.wait(&self.event));
        (
            std::slice::from_raw_parts_mut(self.cu_device_ptr as *mut T, self.len),
            super::SyncOnDrop::Record(Some((&self.event, stream))),
        )
    }
}
320
unsafe impl<'a, 'b: 'a, T> PushKernelArg<&'b UnifiedSlice<T>> for LaunchArgs<'a> {
    #[inline(always)]
    fn arg(&mut self, arg: &'b UnifiedSlice<T>) -> &mut Self {
        // Validate the launch stream may access this attachment; errors are
        // recorded on the context (builder API returns no Result).
        self.stream
            .ctx
            .record_err(arg.check_device_access(self.stream));
        // The launch waits on prior work and records completion via the event.
        self.waits.push(&arg.event);
        self.records.push(&arg.event);
        // Kernel parameters are passed by address: push a pointer to the
        // device-pointer value.
        self.args
            .push((&arg.cu_device_ptr) as *const sys::CUdeviceptr as _);
        self
    }
}
334
unsafe impl<'a, 'b: 'a, T> PushKernelArg<&'b mut UnifiedSlice<T>> for LaunchArgs<'a> {
    #[inline(always)]
    fn arg(&mut self, arg: &'b mut UnifiedSlice<T>) -> &mut Self {
        // Identical to the shared-reference impl: check access, wire up
        // event waits/records, and push the address of the device pointer.
        self.stream
            .ctx
            .record_err(arg.check_device_access(self.stream));
        self.waits.push(&arg.event);
        self.records.push(&arg.event);
        self.args
            .push((&arg.cu_device_ptr) as *const sys::CUdeviceptr as _);
        self
    }
}
348
#[cfg(test)]
mod tests {
    #![allow(clippy::needless_range_loop)]

    use crate::driver::{LaunchConfig, PushKernelArg};

    use super::*;

    /// Element count used by every test below.
    const LEN: usize = 100;

    /// Kernel shared by all tests: asserts `buf[i] == (float)i` for the
    /// first 100 threads.
    const ASSERT_IOTA_KERNEL: &str = "
extern \"C\" __global__ void kernel(float *buf) {
    if (threadIdx.x < 100) {
        assert(buf[threadIdx.x] == static_cast<float>(threadIdx.x));
    }
}";

    /// Writes `0.0, 1.0, ..` into `a` through a host-visible slice, then
    /// reads it back from the host and checks the values round-trip.
    fn fill_and_verify_iota(a: &mut UnifiedSlice<f32>) -> Result<(), DriverError> {
        {
            let buf = a.as_mut_slice()?;
            for i in 0..LEN {
                buf[i] = i as f32;
            }
        }
        verify_iota(a)
    }

    /// Asserts every element of `a` equals its index (host-side read).
    fn verify_iota(a: &UnifiedSlice<f32>) -> Result<(), DriverError> {
        let buf = a.as_slice()?;
        for i in 0..LEN {
            assert_eq!(buf[i], i as f32);
        }
        Ok(())
    }

    /// Shared tail of every test: copy device->host and device->device on
    /// `stream` and check the copies, then memset `a` to zero and verify the
    /// zeros from the host.
    fn copy_and_zero_checks(
        stream: &Arc<CudaStream>,
        a: &mut UnifiedSlice<f32>,
    ) -> Result<(), DriverError> {
        let vs = stream.memcpy_dtov(&*a)?;
        for i in 0..LEN {
            assert_eq!(vs[i], i as f32);
        }

        let b = stream.memcpy_stod(&*a)?;
        let vs = stream.memcpy_dtov(&b)?;
        for i in 0..LEN {
            assert_eq!(vs[i], i as f32);
        }

        stream.memset_zeros(a)?;
        let buf = a.as_slice()?;
        for i in 0..LEN {
            assert_eq!(buf[i], 0.0);
        }
        Ok(())
    }

    #[test]
    fn test_unified_memory_global() -> Result<(), DriverError> {
        let ctx = CudaContext::new(0)?;

        let mut a = unsafe { ctx.alloc_unified::<f32>(LEN, true) }?;
        fill_and_verify_iota(&mut a)?;

        let ptx = crate::nvrtc::compile_ptx(ASSERT_IOTA_KERNEL).unwrap();
        let module = ctx.load_module(ptx)?;
        let f = module.load_function("kernel")?;

        // GLOBAL attachment: any stream may access the memory.
        let stream1 = ctx.default_stream();
        unsafe {
            stream1
                .launch_builder(&f)
                .arg(&mut a)
                .launch(LaunchConfig::for_num_elems(LEN as u32))
        }?;
        stream1.synchronize()?;

        let stream2 = ctx.new_stream()?;
        unsafe {
            stream2
                .launch_builder(&f)
                .arg(&mut a)
                .launch(LaunchConfig::for_num_elems(LEN as u32))
        }?;
        stream2.synchronize()?;

        verify_iota(&a)?;
        copy_and_zero_checks(&stream1, &mut a)
    }

    #[test]
    fn test_unified_memory_host() -> Result<(), DriverError> {
        let ctx = CudaContext::new(0)?;

        // HOST attachment this time; device access still works when the
        // device has concurrent managed access.
        let mut a = unsafe { ctx.alloc_unified::<f32>(LEN, false) }?;
        fill_and_verify_iota(&mut a)?;

        let ptx = crate::nvrtc::compile_ptx(ASSERT_IOTA_KERNEL).unwrap();
        let module = ctx.load_module(ptx)?;
        let f = module.load_function("kernel")?;

        let stream1 = ctx.default_stream();
        unsafe {
            stream1
                .launch_builder(&f)
                .arg(&mut a)
                .launch(LaunchConfig::for_num_elems(LEN as u32))
        }?;
        stream1.synchronize()?;

        let stream2 = ctx.new_stream()?;
        unsafe {
            stream2
                .launch_builder(&f)
                .arg(&mut a)
                .launch(LaunchConfig::for_num_elems(LEN as u32))
        }?;
        stream2.synchronize()?;

        verify_iota(&a)?;
        copy_and_zero_checks(&stream1, &mut a)
    }

    #[test]
    fn test_unified_memory_single_stream() -> Result<(), DriverError> {
        let ctx = CudaContext::new(0)?;

        let mut a = unsafe { ctx.alloc_unified::<f32>(LEN, true) }?;
        fill_and_verify_iota(&mut a)?;

        let ptx = crate::nvrtc::compile_ptx(ASSERT_IOTA_KERNEL).unwrap();
        let module = ctx.load_module(ptx)?;
        let f = module.load_function("kernel")?;

        // After attaching with SINGLE, only `stream2` may touch the memory.
        let stream2 = ctx.new_stream()?;
        a.attach(&stream2, sys::CUmemAttach_flags::CU_MEM_ATTACH_SINGLE)?;
        unsafe {
            stream2
                .launch_builder(&f)
                .arg(&mut a)
                .launch(LaunchConfig::for_num_elems(LEN as u32))
        }?;
        stream2.synchronize()?;

        // Any other stream must be rejected by `check_device_access`.
        let stream1 = ctx.default_stream();
        unsafe {
            stream1
                .launch_builder(&f)
                .arg(&mut a)
                .launch(LaunchConfig::for_num_elems(LEN as u32))
        }
        .expect_err("Other stream access should've failed");

        verify_iota(&a)?;
        copy_and_zero_checks(&stream2, &mut a)
    }
}