Skip to main content

baracuda_driver/
memory.rs

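//! Device-memory types.
//!
//! - [`DeviceBuffer<T>`] — owned, typed GPU allocation.
//! - [`DeviceSlice<'_, T>`] / [`DeviceSliceMut<'_, T>`] — non-owning views
//!   into a buffer, with borrow-checker-tracked lifetimes.
//! - [`ManagedBuffer<T>`] — owned unified (managed) memory, addressable from
//!   both the host and the GPU.
//!
//! The module also provides thin free-function wrappers over the driver's
//! memset, memcpy, peer-copy, and unified-memory advise/prefetch entry points.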
6
7use core::ffi::c_void;
8use core::marker::PhantomData;
9use core::mem::size_of;
10use core::ops::Range;
11
12use baracuda_cuda_sys::{driver, CUdeviceptr};
13use baracuda_types::{DeviceRepr, KernelArg};
14
15use crate::context::Context;
16use crate::error::{check, Result};
17use crate::stream::Stream;
18
19/// Owned, typed allocation of device memory.
20///
21/// The underlying bytes are freed when the buffer drops. Clone/copy is
22/// deliberately *not* implemented — copying `len` bytes of device memory is
23/// not free, so baracuda makes the user spell it out as an explicit
24/// stream-ordered D2D memcpy.
25pub struct DeviceBuffer<T: DeviceRepr> {
26    ptr: CUdeviceptr,
27    len: usize,
28    context: Context,
29    _marker: PhantomData<T>,
30}
31
32// SAFETY: a device pointer can be moved between threads but concurrent
33// mutation requires external synchronization (streams).
34unsafe impl<T: DeviceRepr + Send> Send for DeviceBuffer<T> {}
35
36impl<T: DeviceRepr> core::fmt::Debug for DeviceBuffer<T> {
37    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
38        f.debug_struct("DeviceBuffer")
39            .field("ptr", &format_args!("{:#x}", self.ptr.0))
40            .field("len", &self.len)
41            .field("type", &core::any::type_name::<T>())
42            .finish()
43    }
44}
45
46impl<T: DeviceRepr> DeviceBuffer<T> {
47    /// Allocate an uninitialized buffer of `len` elements on the given context's device.
48    ///
49    /// `len == 0` (or a zero-sized `T`) short-circuits: CUDA rejects 0-byte
50    /// allocations with `CUDA_ERROR_INVALID_VALUE`, so we produce a sentinel
51    /// null-pointer buffer. [`Drop`] knows to skip the free on such buffers,
52    /// and every copy method below treats `len == 0` as a no-op.
53    pub fn new(context: &Context, len: usize) -> Result<Self> {
54        let bytes = len
55            .checked_mul(size_of::<T>())
56            .expect("overflow computing allocation size");
57        if bytes == 0 {
58            return Ok(Self {
59                ptr: CUdeviceptr(0),
60                len,
61                context: context.clone(),
62                _marker: PhantomData,
63            });
64        }
65        context.set_current()?;
66        let d = driver()?;
67        let cu = d.cu_mem_alloc()?;
68        let mut ptr = CUdeviceptr(0);
69        // SAFETY: `ptr` is writable; `bytes > 0` so cuMemAlloc is happy.
70        check(unsafe { cu(&mut ptr, bytes) })?;
71        Ok(Self {
72            ptr,
73            len,
74            context: context.clone(),
75            _marker: PhantomData,
76        })
77    }
78
79    /// Allocate `len` elements **asynchronously** on `stream` using the
80    /// device's default memory pool. Requires CUDA 11.2+.
81    ///
82    /// Unlike [`new`](Self::new), this call doesn't block — the
83    /// allocation becomes usable for any subsequent operation on `stream`
84    /// in stream order. Use [`free_async`](Self::free_async) to reclaim
85    /// on the same stream, or let `Drop` reclaim synchronously.
86    pub fn new_async(context: &Context, len: usize, stream: &Stream) -> Result<Self> {
87        let bytes = len
88            .checked_mul(size_of::<T>())
89            .expect("overflow computing allocation size");
90        if bytes == 0 {
91            return Ok(Self {
92                ptr: CUdeviceptr(0),
93                len,
94                context: context.clone(),
95                _marker: PhantomData,
96            });
97        }
98        context.set_current()?;
99        let d = driver()?;
100        let cu = d.cu_mem_alloc_async()?;
101        let mut ptr = CUdeviceptr(0);
102        // SAFETY: `ptr` writable; `stream` is live.
103        check(unsafe { cu(&mut ptr, bytes, stream.as_raw()) })?;
104        Ok(Self {
105            ptr,
106            len,
107            context: context.clone(),
108            _marker: PhantomData,
109        })
110    }
111
112    /// Free `self` asynchronously on `stream`. The buffer becomes invalid
113    /// stream-ordered-after this call completes on the device. Consumes
114    /// `self` so `Drop` does not also try to free.
115    ///
116    /// Requires CUDA 11.2+.
117    pub fn free_async(mut self, stream: &Stream) -> Result<()> {
118        let ptr = core::mem::replace(&mut self.ptr, CUdeviceptr(0));
119        if ptr.0 == 0 {
120            return Ok(());
121        }
122        let d = driver()?;
123        let cu = d.cu_mem_free_async()?;
124        check(unsafe { cu(ptr, stream.as_raw()) })
125    }
126
127    /// Allocate and fill with zero bytes. Zero-length allocations are a
128    /// no-op (no `cuMemsetD8` call is issued).
129    pub fn zeros(context: &Context, len: usize) -> Result<Self> {
130        let buf = Self::new(context, len)?;
131        let bytes = len * size_of::<T>();
132        if bytes == 0 {
133            return Ok(buf);
134        }
135        let d = driver()?;
136        let cu = d.cu_memset_d8()?;
137        check(unsafe { cu(buf.ptr, 0, bytes) })?;
138        Ok(buf)
139    }
140
141    /// Synchronously fill this buffer with zero bytes via `cuMemsetD8`.
142    /// Empty buffers are a no-op (no FFI call). Use this to reuse an
143    /// existing allocation when you want zeroed contents without paying
144    /// the allocation cost a second time.
145    pub fn zero(&self) -> Result<()> {
146        let bytes = self.len * size_of::<T>();
147        if bytes == 0 {
148            return Ok(());
149        }
150        let d = driver()?;
151        let cu = d.cu_memset_d8()?;
152        check(unsafe { cu(self.ptr, 0, bytes) })
153    }
154
155    /// Stream-ordered zero-fill via `cuMemsetD8Async`. Empty buffers are
156    /// a no-op. The fill is ordered with respect to other work submitted
157    /// to `stream`; synchronize the stream before reading from the host.
158    pub fn zero_async(&self, stream: &Stream) -> Result<()> {
159        let bytes = self.len * size_of::<T>();
160        if bytes == 0 {
161            return Ok(());
162        }
163        let d = driver()?;
164        let cu = d.cu_memset_d8_async()?;
165        check(unsafe { cu(self.ptr, 0, bytes, stream.as_raw()) })
166    }
167
168    /// Allocate and copy `src` synchronously from host memory. Empty
169    /// slices produce a sentinel zero-length buffer (no CUDA calls).
170    pub fn from_slice(context: &Context, src: &[T]) -> Result<Self> {
171        let buf = Self::new(context, src.len())?;
172        buf.copy_from_host(src)?;
173        Ok(buf)
174    }
175
176    /// Synchronous H2D copy. `src.len()` must equal `self.len()`.
177    /// No-op when the buffer is empty — no `cuMemcpy` is issued.
178    pub fn copy_from_host(&self, src: &[T]) -> Result<()> {
179        assert_eq!(
180            src.len(),
181            self.len,
182            "copy_from_host: source length {} != buffer length {}",
183            src.len(),
184            self.len
185        );
186        let bytes = self.len * size_of::<T>();
187        if bytes == 0 {
188            return Ok(());
189        }
190        let d = driver()?;
191        let cu = d.cu_memcpy_htod()?;
192        // SAFETY: `self.ptr` is a valid device pointer for `bytes` bytes;
193        // `src.as_ptr()` is valid for reads of `bytes` bytes.
194        check(unsafe { cu(self.ptr, src.as_ptr() as *const c_void, bytes) })
195    }
196
197    /// Synchronous D2H copy. `dst.len()` must equal `self.len()`.
198    /// No-op on empty buffers.
199    pub fn copy_to_host(&self, dst: &mut [T]) -> Result<()> {
200        assert_eq!(
201            dst.len(),
202            self.len,
203            "copy_to_host: destination length {} != buffer length {}",
204            dst.len(),
205            self.len
206        );
207        let bytes = self.len * size_of::<T>();
208        if bytes == 0 {
209            return Ok(());
210        }
211        let d = driver()?;
212        let cu = d.cu_memcpy_dtoh()?;
213        // SAFETY: mirror of `copy_from_host`; `dst` is valid for writes.
214        check(unsafe { cu(dst.as_mut_ptr() as *mut c_void, self.ptr, bytes) })
215    }
216
217    /// Asynchronous H2D copy on `stream`. No-op on empty buffers.
218    pub fn copy_from_host_async(&self, src: &[T], stream: &Stream) -> Result<()> {
219        assert_eq!(src.len(), self.len);
220        let bytes = self.len * size_of::<T>();
221        if bytes == 0 {
222            return Ok(());
223        }
224        let d = driver()?;
225        let cu = d.cu_memcpy_htod_async()?;
226        check(unsafe {
227            cu(
228                self.ptr,
229                src.as_ptr() as *const c_void,
230                bytes,
231                stream.as_raw(),
232            )
233        })
234    }
235
236    /// Asynchronous D2H copy on `stream`. No-op on empty buffers.
237    pub fn copy_to_host_async(&self, dst: &mut [T], stream: &Stream) -> Result<()> {
238        assert_eq!(dst.len(), self.len);
239        let bytes = self.len * size_of::<T>();
240        if bytes == 0 {
241            return Ok(());
242        }
243        let d = driver()?;
244        let cu = d.cu_memcpy_dtoh_async()?;
245        check(unsafe {
246            cu(
247                dst.as_mut_ptr() as *mut c_void,
248                self.ptr,
249                bytes,
250                stream.as_raw(),
251            )
252        })
253    }
254
255    /// Device-to-device copy into another buffer of the same length.
256    /// No-op on empty buffers.
257    pub fn copy_to_device(&self, dst: &DeviceBuffer<T>) -> Result<()> {
258        assert_eq!(dst.len, self.len);
259        let bytes = self.len * size_of::<T>();
260        if bytes == 0 {
261            return Ok(());
262        }
263        let d = driver()?;
264        let cu = d.cu_memcpy_dtod()?;
265        check(unsafe { cu(dst.ptr, self.ptr, bytes) })
266    }
267
268    /// Asynchronous device-to-device copy on `stream`. No-op on empty buffers.
269    pub fn copy_to_device_async(&self, dst: &DeviceBuffer<T>, stream: &Stream) -> Result<()> {
270        assert_eq!(dst.len, self.len);
271        let bytes = self.len * size_of::<T>();
272        if bytes == 0 {
273            return Ok(());
274        }
275        let d = driver()?;
276        let cu = d.cu_memcpy_dtod_async()?;
277        check(unsafe { cu(dst.ptr, self.ptr, bytes, stream.as_raw()) })
278    }
279
280    /// Number of elements in the buffer.
281    #[inline]
282    pub fn len(&self) -> usize {
283        self.len
284    }
285
286    /// Size of the buffer in bytes.
287    #[inline]
288    pub fn byte_size(&self) -> usize {
289        self.len * size_of::<T>()
290    }
291
292    /// `true` if the buffer has zero elements.
293    #[inline]
294    pub fn is_empty(&self) -> bool {
295        self.len == 0
296    }
297
298    /// The [`Context`] this buffer was allocated in.
299    #[inline]
300    pub fn context(&self) -> &Context {
301        &self.context
302    }
303
304    /// Raw device pointer. Use with care — baracuda still owns the allocation.
305    #[inline]
306    pub fn as_raw(&self) -> CUdeviceptr {
307        self.ptr
308    }
309
310    /// Borrow the whole buffer as a [`DeviceSlice<'_, T>`].
311    #[inline]
312    pub fn as_slice(&self) -> DeviceSlice<'_, T> {
313        DeviceSlice {
314            ptr: self.ptr,
315            len: self.len,
316            _marker: PhantomData,
317        }
318    }
319
320    /// Borrow the whole buffer as a [`DeviceSliceMut<'_, T>`].
321    #[inline]
322    pub fn as_slice_mut(&mut self) -> DeviceSliceMut<'_, T> {
323        DeviceSliceMut {
324            ptr: self.ptr,
325            len: self.len,
326            _marker: PhantomData,
327        }
328    }
329
330    /// Borrow a sub-range of the buffer as an immutable [`DeviceSlice`].
331    ///
332    /// Panics if the range is out of bounds or inverted. Element indices
333    /// are used — the byte offset is `range.start * size_of::<T>()`.
334    ///
335    /// ```no_run
336    /// # use baracuda_driver::{Context, Device, DeviceBuffer};
337    /// # fn demo() -> baracuda_driver::Result<()> {
338    /// let ctx = Context::new(&Device::get(0)?)?;
339    /// let buf: DeviceBuffer<f32> = DeviceBuffer::zeros(&ctx, 1024)?;
340    /// let first_half = buf.slice(0..512);
341    /// let tail = buf.slice(512..1024);
342    /// # let _ = (first_half, tail); Ok(()) }
343    /// ```
344    #[inline]
345    pub fn slice(&self, range: Range<usize>) -> DeviceSlice<'_, T> {
346        assert!(
347            range.start <= range.end && range.end <= self.len,
348            "DeviceBuffer::slice({}..{}) out of bounds for len {}",
349            range.start,
350            range.end,
351            self.len,
352        );
353        DeviceSlice {
354            ptr: CUdeviceptr(self.ptr.0 + (range.start * size_of::<T>()) as u64),
355            len: range.end - range.start,
356            _marker: PhantomData,
357        }
358    }
359
360    /// Mutable counterpart of [`slice`](Self::slice).
361    #[inline]
362    pub fn slice_mut(&mut self, range: Range<usize>) -> DeviceSliceMut<'_, T> {
363        assert!(
364            range.start <= range.end && range.end <= self.len,
365            "DeviceBuffer::slice_mut({}..{}) out of bounds for len {}",
366            range.start,
367            range.end,
368            self.len,
369        );
370        DeviceSliceMut {
371            ptr: CUdeviceptr(self.ptr.0 + (range.start * size_of::<T>()) as u64),
372            len: range.end - range.start,
373            _marker: PhantomData,
374        }
375    }
376}
377
378impl DeviceBuffer<u8> {
379    /// Reinterpret the byte buffer as an immutable typed [`DeviceSlice<'_, U>`].
380    ///
381    /// The recommended primitive for layering safe typed APIs over a
382    /// byte-shaped storage substrate — e.g. a unified-binding table that
383    /// stores all device tensors as `DeviceBuffer<u8>` and only acquires
384    /// element types at the edges where it calls into typed CUDA
385    /// libraries.
386    ///
387    /// Alignment is guaranteed: `cuMemAlloc` returns 256-byte-aligned
388    /// pointers, which satisfies any `U: DeviceRepr` we ship today and
389    /// any reasonable user type.
390    ///
391    /// # Panics
392    ///
393    /// Panics if the buffer's byte length isn't an integer multiple of
394    /// `size_of::<U>()`. Zero-sized `U` produces a zero-length view.
395    #[inline]
396    pub fn view_as<U: DeviceRepr>(&self) -> DeviceSlice<'_, U> {
397        let elem = size_of::<U>();
398        if elem == 0 {
399            return DeviceSlice {
400                ptr: self.ptr,
401                len: 0,
402                _marker: PhantomData,
403            };
404        }
405        assert!(
406            self.len % elem == 0,
407            "DeviceBuffer<u8>::view_as: byte length {} not divisible by size_of::<{}>() = {}",
408            self.len,
409            core::any::type_name::<U>(),
410            elem,
411        );
412        DeviceSlice {
413            ptr: self.ptr,
414            len: self.len / elem,
415            _marker: PhantomData,
416        }
417    }
418
419    /// Mutable counterpart of [`view_as`](Self::view_as).
420    #[inline]
421    pub fn view_as_mut<U: DeviceRepr>(&mut self) -> DeviceSliceMut<'_, U> {
422        let elem = size_of::<U>();
423        if elem == 0 {
424            return DeviceSliceMut {
425                ptr: self.ptr,
426                len: 0,
427                _marker: PhantomData,
428            };
429        }
430        assert!(
431            self.len % elem == 0,
432            "DeviceBuffer<u8>::view_as_mut: byte length {} not divisible by size_of::<{}>() = {}",
433            self.len,
434            core::any::type_name::<U>(),
435            elem,
436        );
437        DeviceSliceMut {
438            ptr: self.ptr,
439            len: self.len / elem,
440            _marker: PhantomData,
441        }
442    }
443}
444
445// ---- Unified (managed) memory --------------------------------------------
446
447/// Attach-mode for [`ManagedBuffer::new_with_flags`].
448#[derive(Copy, Clone, Debug, Eq, PartialEq, Default)]
449pub enum ManagedAttach {
450    /// Accessible from any stream on any device. **Default.**
451    #[default]
452    Global,
453    /// Pinned to the host — accessible from the host, not from the GPU.
454    Host,
455    /// Accessible only on the stream it was later attached to.
456    Single,
457}
458
459impl ManagedAttach {
460    #[inline]
461    fn raw(self) -> u32 {
462        use baracuda_cuda_sys::types::CUmemAttach_flags as F;
463        match self {
464            ManagedAttach::Global => F::GLOBAL,
465            ManagedAttach::Host => F::HOST,
466            ManagedAttach::Single => F::SINGLE,
467        }
468    }
469}
470
471/// Memory-usage advice for `cuMemAdvise`.
472#[derive(Copy, Clone, Debug, Eq, PartialEq)]
473pub enum MemAdvise {
474    SetReadMostly,
475    UnsetReadMostly,
476    SetPreferredLocation,
477    UnsetPreferredLocation,
478    SetAccessedBy,
479    UnsetAccessedBy,
480}
481
482impl MemAdvise {
483    #[inline]
484    fn raw(self) -> i32 {
485        use baracuda_cuda_sys::types::CUmem_advise as A;
486        match self {
487            MemAdvise::SetReadMostly => A::SET_READ_MOSTLY,
488            MemAdvise::UnsetReadMostly => A::UNSET_READ_MOSTLY,
489            MemAdvise::SetPreferredLocation => A::SET_PREFERRED_LOCATION,
490            MemAdvise::UnsetPreferredLocation => A::UNSET_PREFERRED_LOCATION,
491            MemAdvise::SetAccessedBy => A::SET_ACCESSED_BY,
492            MemAdvise::UnsetAccessedBy => A::UNSET_ACCESSED_BY,
493        }
494    }
495}
496
497/// Owned allocation of **unified (managed) memory** — a single pointer that
498/// is accessible from both the host and the GPU, with on-demand migration
499/// handled by the driver. Compare with [`DeviceBuffer`], which is
500/// device-only and requires explicit memcpys.
501///
502/// On a discrete GPU, accessing the buffer from host code after a kernel
503/// finishes (and vice versa) requires a stream synchronize in the
504/// unified-memory model; [`ManagedBuffer::as_host_slice`] assumes the
505/// caller has done that.
506pub struct ManagedBuffer<T: DeviceRepr> {
507    ptr: CUdeviceptr,
508    len: usize,
509    context: Context,
510    _marker: PhantomData<T>,
511}
512
513unsafe impl<T: DeviceRepr + Send> Send for ManagedBuffer<T> {}
514
515impl<T: DeviceRepr> core::fmt::Debug for ManagedBuffer<T> {
516    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
517        f.debug_struct("ManagedBuffer")
518            .field("ptr", &format_args!("{:#x}", self.ptr.0))
519            .field("len", &self.len)
520            .field("type", &core::any::type_name::<T>())
521            .finish()
522    }
523}
524
525impl<T: DeviceRepr> ManagedBuffer<T> {
526    /// Allocate `len` elements of unified memory with the default
527    /// ([`ManagedAttach::Global`]) attach mode.
528    pub fn new(context: &Context, len: usize) -> Result<Self> {
529        Self::new_with_flags(context, len, ManagedAttach::Global)
530    }
531
532    /// Allocate with an explicit [`ManagedAttach`] mode.
533    pub fn new_with_flags(context: &Context, len: usize, attach: ManagedAttach) -> Result<Self> {
534        context.set_current()?;
535        let d = driver()?;
536        let cu = d.cu_mem_alloc_managed()?;
537        let bytes = len
538            .checked_mul(size_of::<T>())
539            .expect("overflow computing allocation size");
540        let mut ptr = CUdeviceptr(0);
541        // SAFETY: writable out-pointer; positive byte count.
542        check(unsafe { cu(&mut ptr, bytes, attach.raw()) })?;
543        Ok(Self {
544            ptr,
545            len,
546            context: context.clone(),
547            _marker: PhantomData,
548        })
549    }
550
551    /// Provide a hint to the Unified-Memory subsystem about how this range
552    /// will be accessed. `device` is the ordinal this advice targets (e.g.
553    /// the compute device for `SET_ACCESSED_BY`); pass the current device's
554    /// ordinal when in doubt.
555    pub fn advise(&self, advice: MemAdvise, device: &crate::Device) -> Result<()> {
556        let d = driver()?;
557        let cu = d.cu_mem_advise()?;
558        let bytes = self.len * size_of::<T>();
559        check(unsafe { cu(self.ptr, bytes, advice.raw(), device.as_raw()) })
560    }
561
562    /// Asynchronously prefetch this range to `device` on `stream`.
563    pub fn prefetch_async(&self, device: &crate::Device, stream: &Stream) -> Result<()> {
564        let d = driver()?;
565        let cu = d.cu_mem_prefetch_async()?;
566        let bytes = self.len * size_of::<T>();
567        check(unsafe { cu(self.ptr, bytes, device.as_raw(), stream.as_raw()) })
568    }
569
570    /// Access the buffer as a host slice. Safe to call on integrated GPUs or
571    /// after a synchronize on discrete GPUs; otherwise you'll see stale data.
572    ///
573    /// # Safety
574    ///
575    /// The caller must ensure:
576    /// 1. No concurrent kernel is writing to this buffer.
577    /// 2. On discrete GPUs, a relevant synchronize has been issued since
578    ///    the last device-side write.
579    pub unsafe fn as_host_slice(&self) -> &[T] { unsafe {
580        core::slice::from_raw_parts(self.ptr.0 as *const T, self.len)
581    }}
582
583    /// Mutable host view. Same safety rules as [`as_host_slice`](Self::as_host_slice).
584    ///
585    /// # Safety
586    ///
587    /// The caller must ensure no concurrent device or host access.
588    pub unsafe fn as_host_slice_mut(&mut self) -> &mut [T] { unsafe {
589        core::slice::from_raw_parts_mut(self.ptr.0 as *mut T, self.len)
590    }}
591
592    /// Number of elements.
593    #[inline]
594    pub fn len(&self) -> usize {
595        self.len
596    }
597
598    /// `true` if zero elements.
599    #[inline]
600    pub fn is_empty(&self) -> bool {
601        self.len == 0
602    }
603
604    /// Raw device pointer — the same value as the host pointer under UVM.
605    #[inline]
606    pub fn as_raw(&self) -> CUdeviceptr {
607        self.ptr
608    }
609
610    /// Owning context.
611    #[inline]
612    pub fn context(&self) -> &Context {
613        &self.context
614    }
615}
616
617impl<T: DeviceRepr> Drop for ManagedBuffer<T> {
618    fn drop(&mut self) {
619        if self.ptr.0 == 0 {
620            return;
621        }
622        if let Ok(d) = driver() {
623            if let Ok(cu) = d.cu_mem_free() {
624                let _ = unsafe { cu(self.ptr) };
625            }
626        }
627    }
628}
629
630/// Current device's free and total global memory, in bytes.
631///
632/// Requires a CUDA context to be current on the calling thread.
633pub fn mem_get_info() -> Result<(u64, u64)> {
634    let d = driver()?;
635    let cu = d.cu_mem_get_info()?;
636    let mut free: usize = 0;
637    let mut total: usize = 0;
638    check(unsafe { cu(&mut free, &mut total) })?;
639    Ok((free as u64, total as u64))
640}
641
642/// Peer-to-peer device memory copy between two contexts — the pointers
643/// must be valid device pointers in their respective contexts, and peer
644/// access must be enabled (see [`Context::enable_peer_access`]).
645pub fn memcpy_peer<T: DeviceRepr>(
646    dst: &DeviceBuffer<T>,
647    dst_ctx: &Context,
648    src: &DeviceBuffer<T>,
649    src_ctx: &Context,
650) -> Result<()> {
651    assert_eq!(dst.len(), src.len());
652    let d = driver()?;
653    let cu = d.cu_memcpy_peer()?;
654    let bytes = src.len() * size_of::<T>();
655    check(unsafe {
656        cu(
657            dst.as_raw(),
658            dst_ctx.as_raw(),
659            src.as_raw(),
660            src_ctx.as_raw(),
661            bytes,
662        )
663    })
664}
665
666/// Async peer-to-peer device memory copy ordered on `stream`.
667pub fn memcpy_peer_async<T: DeviceRepr>(
668    dst: &DeviceBuffer<T>,
669    dst_ctx: &Context,
670    src: &DeviceBuffer<T>,
671    src_ctx: &Context,
672    stream: &Stream,
673) -> Result<()> {
674    assert_eq!(dst.len(), src.len());
675    let d = driver()?;
676    let cu = d.cu_memcpy_peer_async()?;
677    let bytes = src.len() * size_of::<T>();
678    check(unsafe {
679        cu(
680            dst.as_raw(),
681            dst_ctx.as_raw(),
682            src.as_raw(),
683            src_ctx.as_raw(),
684            bytes,
685            stream.as_raw(),
686        )
687    })
688}
689
690/// Fill `count` 16-bit elements at `dst` with `value` (synchronous).
691pub fn memset_u16(dst: CUdeviceptr, value: u16, count: usize) -> Result<()> {
692    let d = driver()?;
693    let cu = d.cu_memset_d16()?;
694    check(unsafe { cu(dst, value, count) })
695}
696
697/// Async variant of [`memset_u16`] ordered on `stream`.
698pub fn memset_u16_async(dst: CUdeviceptr, value: u16, count: usize, stream: &Stream) -> Result<()> {
699    let d = driver()?;
700    let cu = d.cu_memset_d16_async()?;
701    check(unsafe { cu(dst, value, count, stream.as_raw()) })
702}
703
704/// Async 8-bit memset on `stream`.
705pub fn memset_u8_async(dst: CUdeviceptr, value: u8, count: usize, stream: &Stream) -> Result<()> {
706    let d = driver()?;
707    let cu = d.cu_memset_d8_async()?;
708    check(unsafe { cu(dst, value, count, stream.as_raw()) })
709}
710
711/// Async 32-bit memset on `stream`.
712pub fn memset_u32_async(dst: CUdeviceptr, value: u32, count: usize, stream: &Stream) -> Result<()> {
713    let d = driver()?;
714    let cu = d.cu_memset_d32_async()?;
715    check(unsafe { cu(dst, value, count, stream.as_raw()) })
716}
717
718/// Synchronous 32-bit memset.
719pub fn memset_u32(dst: CUdeviceptr, value: u32, count: usize) -> Result<()> {
720    let d = driver()?;
721    let cu = d.cu_memset_d32()?;
722    check(unsafe { cu(dst, value, count) })
723}
724
725/// 2D pitched 8-bit memset: fill a `width × height` rectangle of bytes
726/// at `dst` (row pitch `pitch`) with `value`.
727pub fn memset_2d_u8(
728    dst: CUdeviceptr,
729    pitch: usize,
730    value: u8,
731    width: usize,
732    height: usize,
733) -> Result<()> {
734    let d = driver()?;
735    let cu = d.cu_memset_d2d8()?;
736    check(unsafe { cu(dst, pitch, value, width, height) })
737}
738
739/// 2D pitched 16-bit memset.
740pub fn memset_2d_u16(
741    dst: CUdeviceptr,
742    pitch: usize,
743    value: u16,
744    width: usize,
745    height: usize,
746) -> Result<()> {
747    let d = driver()?;
748    let cu = d.cu_memset_d2d16()?;
749    check(unsafe { cu(dst, pitch, value, width, height) })
750}
751
752/// 2D pitched 32-bit memset.
753pub fn memset_2d_u32(
754    dst: CUdeviceptr,
755    pitch: usize,
756    value: u32,
757    width: usize,
758    height: usize,
759) -> Result<()> {
760    let d = driver()?;
761    let cu = d.cu_memset_d2d32()?;
762    check(unsafe { cu(dst, pitch, value, width, height) })
763}
764
765/// Generic byte-count copy that works on any pair of CUDA-addressable
766/// pointers (device, unified, host-pinned). Use this when the kind of
767/// memory at each end isn't known at the call site (typical for
768/// runtime-decided unified-memory paths).
769///
770/// # Safety
771///
772/// Both `dst` and `src` must be CUDA-addressable for at least `bytes`
773/// bytes; ranges must not overlap.
774pub unsafe fn memcpy(dst: CUdeviceptr, src: CUdeviceptr, bytes: usize) -> Result<()> { unsafe {
775    let d = driver()?;
776    let cu = d.cu_memcpy()?;
777    check(cu(dst, src, bytes))
778}}
779
780/// Async variant of [`memcpy`] ordered on `stream`.
781///
782/// # Safety
783///
784/// Same buffer rules as [`memcpy`]; additionally `dst` and `src` must
785/// stay valid until the work submitted on `stream` completes.
786pub unsafe fn memcpy_async(
787    dst: CUdeviceptr,
788    src: CUdeviceptr,
789    bytes: usize,
790    stream: &Stream,
791) -> Result<()> { unsafe {
792    let d = driver()?;
793    let cu = d.cu_memcpy_async()?;
794    check(cu(dst, src, bytes, stream.as_raw()))
795}}
796
797// ---- Wave 27: v2 advise/prefetch + VMM reverse lookups ------------------
798
799/// Destination for [`mem_prefetch_v2`] / [`mem_advise_v2`]. Composes the
800/// `CUmemLocation::{type_, id}` pair.
801#[derive(Copy, Clone, Debug, Eq, PartialEq)]
802pub enum PrefetchTarget {
803    /// Prefetch to a specific device.
804    Device(i32),
805    /// Prefetch to the host (unified memory).
806    Host,
807    /// Prefetch to a specific NUMA node on the host.
808    HostNuma(i32),
809    /// Prefetch to the current host thread's NUMA node.
810    HostNumaCurrent,
811}
812
813impl PrefetchTarget {
814    fn as_location(self) -> baracuda_cuda_sys::types::CUmemLocation {
815        use baracuda_cuda_sys::types::CUmemLocationType;
816        let (type_, id) = match self {
817            PrefetchTarget::Device(i) => (CUmemLocationType::DEVICE, i),
818            PrefetchTarget::Host => (CUmemLocationType::HOST, 0),
819            PrefetchTarget::HostNuma(n) => (CUmemLocationType::HOST_NUMA, n),
820            PrefetchTarget::HostNumaCurrent => (CUmemLocationType::HOST_NUMA_CURRENT, 0),
821        };
822        baracuda_cuda_sys::types::CUmemLocation { type_, id }
823    }
824}
825
826/// `cuMemPrefetchAsync_v2` — prefetch `count` bytes starting at `dptr` to
827/// the given [`PrefetchTarget`], ordered on `stream`.
828pub fn mem_prefetch_v2(
829    dptr: CUdeviceptr,
830    count: usize,
831    target: PrefetchTarget,
832    stream: &Stream,
833) -> Result<()> {
834    let d = driver()?;
835    let cu = d.cu_mem_prefetch_async_v2()?;
836    check(unsafe { cu(dptr, count, target.as_location(), 0, stream.as_raw()) })
837}
838
839/// `cuMemAdvise_v2` — unified-memory hint at a specific location.
840pub fn mem_advise_v2(
841    dptr: CUdeviceptr,
842    count: usize,
843    advice: i32,
844    target: PrefetchTarget,
845) -> Result<()> {
846    let d = driver()?;
847    let cu = d.cu_mem_advise_v2()?;
848    check(unsafe { cu(dptr, count, advice, target.as_location()) })
849}
850
851/// Reverse lookup: given a device pointer inside a VMM mapping, bump the
852/// underlying allocation handle's refcount and return it. Pair with
853/// `cuMemRelease` to drop the extra ref.
854pub fn retain_allocation_handle(
855    addr: CUdeviceptr,
856) -> Result<baracuda_cuda_sys::CUmemGenericAllocationHandle> {
857    let d = driver()?;
858    let cu = d.cu_mem_retain_allocation_handle()?;
859    let mut h: baracuda_cuda_sys::CUmemGenericAllocationHandle = 0;
860    check(unsafe { cu(&mut h, addr.0 as *mut core::ffi::c_void) })?;
861    Ok(h)
862}
863
864/// Query the creation props of an existing allocation handle.
865pub fn allocation_properties_from_handle(
866    handle: baracuda_cuda_sys::CUmemGenericAllocationHandle,
867) -> Result<baracuda_cuda_sys::types::CUmemAllocationProp> {
868    let d = driver()?;
869    let cu = d.cu_mem_get_allocation_properties_from_handle()?;
870    let mut prop = baracuda_cuda_sys::types::CUmemAllocationProp::default();
871    check(unsafe { cu(&mut prop, handle) })?;
872    Ok(prop)
873}
874
875/// Export an OS-level handle (e.g. a DMA-buf file descriptor on Linux)
876/// for a `size`-byte VA range starting at `dptr`.
877///
878/// # Safety
879///
880/// `handle_out` must point to a buffer appropriate for the `handle_type`
881/// (typically `*mut c_int` for `DMA_BUF_FD`). `dptr..dptr+size` must be a
882/// fully mapped VMM region.
883pub unsafe fn get_handle_for_address_range(
884    handle_out: *mut core::ffi::c_void,
885    dptr: CUdeviceptr,
886    size: usize,
887    handle_type: i32,
888) -> Result<()> { unsafe {
889    let d = driver()?;
890    let cu = d.cu_mem_get_handle_for_address_range()?;
891    check(cu(handle_out, dptr, size, handle_type, 0))
892}}
893
894impl<T: DeviceRepr> Drop for DeviceBuffer<T> {
895    fn drop(&mut self) {
896        if self.ptr.0 == 0 {
897            return;
898        }
899        if let Ok(d) = driver() {
900            if let Ok(cu) = d.cu_mem_free() {
901                let _ = unsafe { cu(self.ptr) };
902            }
903        }
904    }
905}
906
907/// Immutable view into a range of a [`DeviceBuffer`].
908#[derive(Copy, Clone)]
909pub struct DeviceSlice<'a, T: DeviceRepr> {
910    pub(crate) ptr: CUdeviceptr,
911    pub(crate) len: usize,
912    pub(crate) _marker: PhantomData<&'a T>,
913}
914
915impl<'a, T: DeviceRepr> core::fmt::Debug for DeviceSlice<'a, T> {
916    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
917        f.debug_struct("DeviceSlice")
918            .field("ptr", &format_args!("{:#x}", self.ptr.0))
919            .field("len", &self.len)
920            .finish()
921    }
922}
923
924impl<'a, T: DeviceRepr> DeviceSlice<'a, T> {
925    #[inline]
926    pub fn len(&self) -> usize {
927        self.len
928    }
929    #[inline]
930    pub fn is_empty(&self) -> bool {
931        self.len == 0
932    }
933    #[inline]
934    pub fn as_raw(&self) -> CUdeviceptr {
935        self.ptr
936    }
937
938    /// Construct a [`DeviceSlice`] from a raw device pointer and element count.
939    ///
940    /// The lower-level escape hatch backing
941    /// [`DeviceBuffer::view_as`](DeviceBuffer::view_as) — use this when
942    /// the source buffer is not a baracuda `DeviceBuffer` (e.g. a
943    /// pointer received from a foreign CUDA library, or a byte-shaped
944    /// storage substrate that erases the original buffer type at the
945    /// boundary).
946    ///
947    /// # Safety
948    ///
949    /// The caller must guarantee:
950    ///
951    /// 1. `ptr` points to at least `len * size_of::<T>()` bytes of
952    ///    device memory in the calling thread's current CUDA context.
953    /// 2. The pointer is properly aligned for `T`. `cuMemAlloc` returns
954    ///    256-byte-aligned pointers, but sub-slicing with a non-multiple
955    ///    element offset can reduce alignment.
956    /// 3. The pointed-to region remains live, with no concurrent
957    ///    mutation through any other path, for the lifetime `'b`.
958    /// 4. The contents are a valid bit-pattern for `T` (trivially true
959    ///    for `T: DeviceRepr` since all `DeviceRepr` types have no
960    ///    invalid bit patterns).
961    #[inline]
962    pub unsafe fn from_raw_parts<'b>(ptr: CUdeviceptr, len: usize) -> DeviceSlice<'b, T> {
963        DeviceSlice {
964            ptr,
965            len,
966            _marker: PhantomData,
967        }
968    }
969
970    /// Borrow a sub-range. Panics on out-of-bounds / inverted ranges.
971    #[inline]
972    pub fn slice(&self, range: Range<usize>) -> DeviceSlice<'_, T> {
973        assert!(
974            range.start <= range.end && range.end <= self.len,
975            "DeviceSlice::slice({}..{}) out of bounds for len {}",
976            range.start,
977            range.end,
978            self.len,
979        );
980        DeviceSlice {
981            ptr: CUdeviceptr(self.ptr.0 + (range.start * size_of::<T>()) as u64),
982            len: range.end - range.start,
983            _marker: PhantomData,
984        }
985    }
986}
987
/// Mutable view into a range of a [`DeviceBuffer`].
///
/// Unlike [`DeviceSlice`], this type does not derive `Copy`/`Clone`; its
/// `PhantomData<&'a mut T>` marker makes it borrow-check (and vary) like a
/// `&'a mut [T]`.
pub struct DeviceSliceMut<'a, T: DeviceRepr> {
    /// Device address of element 0.
    pub(crate) ptr: CUdeviceptr,
    /// Number of `T` elements in the view.
    pub(crate) len: usize,
    /// Ties the view to a unique borrow of lifetime `'a`.
    pub(crate) _marker: PhantomData<&'a mut T>,
}
994
995impl<'a, T: DeviceRepr> core::fmt::Debug for DeviceSliceMut<'a, T> {
996    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
997        f.debug_struct("DeviceSliceMut")
998            .field("ptr", &format_args!("{:#x}", self.ptr.0))
999            .field("len", &self.len)
1000            .finish()
1001    }
1002}
1003
1004impl<'a, T: DeviceRepr> DeviceSliceMut<'a, T> {
1005    #[inline]
1006    pub fn len(&self) -> usize {
1007        self.len
1008    }
1009    #[inline]
1010    pub fn is_empty(&self) -> bool {
1011        self.len == 0
1012    }
1013    #[inline]
1014    pub fn as_raw(&self) -> CUdeviceptr {
1015        self.ptr
1016    }
1017
1018    /// Construct a [`DeviceSliceMut`] from a raw device pointer and element count.
1019    ///
1020    /// Mutable counterpart of [`DeviceSlice::from_raw_parts`].
1021    ///
1022    /// # Safety
1023    ///
1024    /// All requirements of [`DeviceSlice::from_raw_parts`] apply, with
1025    /// the additional unique-access guarantee:
1026    ///
1027    /// - No other reference (mutable or shared) may alias the
1028    ///   pointed-to region for the lifetime `'b`. This includes any
1029    ///   `DeviceSlice<'_, U>` constructed from the same underlying
1030    ///   storage, even at a different element type.
1031    #[inline]
1032    pub unsafe fn from_raw_parts<'b>(ptr: CUdeviceptr, len: usize) -> DeviceSliceMut<'b, T> {
1033        DeviceSliceMut {
1034            ptr,
1035            len,
1036            _marker: PhantomData,
1037        }
1038    }
1039
1040    /// Borrow a sub-range immutably. Panics on out-of-bounds / inverted.
1041    #[inline]
1042    pub fn slice(&self, range: Range<usize>) -> DeviceSlice<'_, T> {
1043        assert!(
1044            range.start <= range.end && range.end <= self.len,
1045            "DeviceSliceMut::slice({}..{}) out of bounds for len {}",
1046            range.start,
1047            range.end,
1048            self.len,
1049        );
1050        DeviceSlice {
1051            ptr: CUdeviceptr(self.ptr.0 + (range.start * size_of::<T>()) as u64),
1052            len: range.end - range.start,
1053            _marker: PhantomData,
1054        }
1055    }
1056
1057    /// Borrow a sub-range mutably.
1058    #[inline]
1059    pub fn slice_mut(&mut self, range: Range<usize>) -> DeviceSliceMut<'_, T> {
1060        assert!(
1061            range.start <= range.end && range.end <= self.len,
1062            "DeviceSliceMut::slice_mut({}..{}) out of bounds for len {}",
1063            range.start,
1064            range.end,
1065            self.len,
1066        );
1067        DeviceSliceMut {
1068            ptr: CUdeviceptr(self.ptr.0 + (range.start * size_of::<T>()) as u64),
1069            len: range.end - range.start,
1070            _marker: PhantomData,
1071        }
1072    }
1073
1074    /// Asynchronous H2D copy on `stream`. Mirrors [`DeviceBuffer::copy_from_host_async`]
1075    /// for slice views — useful when the destination is a sub-range of a
1076    /// larger device buffer (e.g. packing CUTLASS grouped-GEMM metadata
1077    /// into a caller-supplied workspace).
1078    pub fn copy_from_host_async(&self, src: &[T], stream: &Stream) -> Result<()> {
1079        assert_eq!(src.len(), self.len);
1080        let bytes = self.len * size_of::<T>();
1081        if bytes == 0 {
1082            return Ok(());
1083        }
1084        let d = driver()?;
1085        let cu = d.cu_memcpy_htod_async()?;
1086        check(unsafe {
1087            cu(
1088                self.ptr,
1089                src.as_ptr() as *const c_void,
1090                bytes,
1091                stream.as_raw(),
1092            )
1093        })
1094    }
1095}
1096
1097// ============================================================================
1098// DevicePtr / DevicePtrMut — generic device-pointer trait surface
1099// ============================================================================
1100
1101/// Anything that can be read as a `[T]` on the device.
1102///
1103/// This is the abstraction over [`DeviceBuffer<T>`], [`DeviceSlice<'_, T>`],
1104/// and (via [`DevicePtrMut`]) [`DeviceSliceMut<'_, T>`] — letting generic code
1105/// accept any of them without fighting the type system.
1106///
1107/// Typical usage:
1108///
1109/// ```no_run
1110/// use baracuda_driver::{DevicePtr, DeviceBuffer, DeviceSlice};
1111/// use baracuda_types::DeviceRepr;
1112///
1113/// fn sum_elements<T: DeviceRepr, P: DevicePtr<T>>(buf: &P) -> usize {
1114///     // You get len() + device_ptr() for free; deeper ops go through
1115///     // the concrete type via `buf.device_ptr()`.
1116///     buf.len()
1117/// }
1118/// ```
1119///
1120/// # Safety
1121///
1122/// `device_ptr()` returns an opaque [`CUdeviceptr`]. Any dereference is
1123/// `unsafe` as always. Implementors must guarantee the pointer is live for
1124/// at least `len() * size_of::<T>()` bytes.
1125pub unsafe trait DevicePtr<T: DeviceRepr> {
1126    /// Raw device pointer to element 0.
1127    fn device_ptr(&self) -> CUdeviceptr;
1128
1129    /// Number of `T` elements visible through this pointer.
1130    fn len(&self) -> usize;
1131
1132    /// `true` if [`len`](Self::len) is 0.
1133    #[inline]
1134    fn is_empty(&self) -> bool {
1135        self.len() == 0
1136    }
1137
1138    /// Size in bytes (`len * size_of::<T>()`).
1139    #[inline]
1140    fn byte_size(&self) -> usize {
1141        self.len() * core::mem::size_of::<T>()
1142    }
1143}
1144
/// A [`DevicePtr`] that supports writes.
///
/// Implementors must hold a unique reference to the underlying storage for
/// the pointer's lifetime — e.g. `&mut DeviceBuffer<T>` or
/// [`DeviceSliceMut<'_, T>`]. This gives the trait the same borrow-checker
/// properties as `&mut [T]`.
///
/// In this module it is implemented for `DeviceBuffer<T>`,
/// `DeviceSliceMut<'_, T>` and `&mut P`; the shared views
/// (`DeviceSlice<'_, T>` and `&P`) only implement [`DevicePtr`].
///
/// # Safety
///
/// Implementors must guarantee that `device_ptr_mut` returns a pointer
/// that is unique for the duration of the `&mut self` borrow — no other
/// live pointer may alias it. Violating this lets concurrent kernels
/// observe writes in any order, breaking the borrow-checker contract
/// `&mut [T]` is meant to mirror.
pub unsafe trait DevicePtrMut<T: DeviceRepr>: DevicePtr<T> {
    /// Raw mutable device pointer to element 0.
    ///
    /// Requiring `&mut self` is what ties the pointer's uniqueness to the
    /// borrow checker.
    fn device_ptr_mut(&mut self) -> CUdeviceptr;
}
1163
1164// ---- Impls on the owned + borrowed device types -------------------------
1165
1166unsafe impl<T: DeviceRepr> DevicePtr<T> for DeviceBuffer<T> {
1167    #[inline]
1168    fn device_ptr(&self) -> CUdeviceptr {
1169        self.ptr
1170    }
1171    #[inline]
1172    fn len(&self) -> usize {
1173        self.len
1174    }
1175}
1176
1177unsafe impl<T: DeviceRepr> DevicePtrMut<T> for DeviceBuffer<T> {
1178    #[inline]
1179    fn device_ptr_mut(&mut self) -> CUdeviceptr {
1180        self.ptr
1181    }
1182}
1183
1184unsafe impl<'a, T: DeviceRepr> DevicePtr<T> for DeviceSlice<'a, T> {
1185    #[inline]
1186    fn device_ptr(&self) -> CUdeviceptr {
1187        self.ptr
1188    }
1189    #[inline]
1190    fn len(&self) -> usize {
1191        self.len
1192    }
1193}
1194
1195unsafe impl<'a, T: DeviceRepr> DevicePtr<T> for DeviceSliceMut<'a, T> {
1196    #[inline]
1197    fn device_ptr(&self) -> CUdeviceptr {
1198        self.ptr
1199    }
1200    #[inline]
1201    fn len(&self) -> usize {
1202        self.len
1203    }
1204}
1205
1206unsafe impl<'a, T: DeviceRepr> DevicePtrMut<T> for DeviceSliceMut<'a, T> {
1207    #[inline]
1208    fn device_ptr_mut(&mut self) -> CUdeviceptr {
1209        self.ptr
1210    }
1211}
1212
1213// References delegate transparently.
1214unsafe impl<T: DeviceRepr, P: DevicePtr<T> + ?Sized> DevicePtr<T> for &P {
1215    #[inline]
1216    fn device_ptr(&self) -> CUdeviceptr {
1217        (**self).device_ptr()
1218    }
1219    #[inline]
1220    fn len(&self) -> usize {
1221        (**self).len()
1222    }
1223}
1224
1225unsafe impl<T: DeviceRepr, P: DevicePtr<T> + ?Sized> DevicePtr<T> for &mut P {
1226    #[inline]
1227    fn device_ptr(&self) -> CUdeviceptr {
1228        (**self).device_ptr()
1229    }
1230    #[inline]
1231    fn len(&self) -> usize {
1232        (**self).len()
1233    }
1234}
1235
1236unsafe impl<T: DeviceRepr, P: DevicePtrMut<T> + ?Sized> DevicePtrMut<T> for &mut P {
1237    #[inline]
1238    fn device_ptr_mut(&mut self) -> CUdeviceptr {
1239        (**self).device_ptr_mut()
1240    }
1241}
1242
1243// ============================================================================
1244// KernelArg auto-marshalling for DeviceBuffer / DeviceSlice / DeviceSliceMut
1245// ============================================================================
1246//
// CUDA kernels receive device buffers as raw `T*` pointers, and
// `cuLaunchKernel` takes its arguments as a `void**` — an array of `void*`
// slots, each pointing at one argument's value (here, the pointer value).
// baracuda's DeviceBuffer/DeviceSlice already store a `CUdeviceptr` inline,
// so the safest thing is to return a pointer *into* the buffer/slice itself;
// the returned pointer stays valid as long as the `&DeviceBuffer` /
// `&DeviceSlice` reference does, which Rust's borrow checker already
// enforces for kernel launches.
1254
1255// SAFETY: `&self.ptr` points to a live `CUdeviceptr` owned by the
1256// DeviceBuffer; it remains valid for as long as the `&self` borrow does,
1257// which spans the kernel launch. CUDA reads the pointer value during
1258// submission and never writes it back through this slot.
1259unsafe impl<T: DeviceRepr> KernelArg for &DeviceBuffer<T> {
1260    #[inline]
1261    fn as_kernel_arg_ptr(&self) -> *mut c_void {
1262        &self.ptr as *const CUdeviceptr as *mut c_void
1263    }
1264}
1265
1266unsafe impl<T: DeviceRepr> KernelArg for &mut DeviceBuffer<T> {
1267    #[inline]
1268    fn as_kernel_arg_ptr(&self) -> *mut c_void {
1269        &self.ptr as *const CUdeviceptr as *mut c_void
1270    }
1271}
1272
1273unsafe impl<'a, T: DeviceRepr> KernelArg for &DeviceSlice<'a, T> {
1274    #[inline]
1275    fn as_kernel_arg_ptr(&self) -> *mut c_void {
1276        &self.ptr as *const CUdeviceptr as *mut c_void
1277    }
1278}
1279
1280unsafe impl<'a, T: DeviceRepr> KernelArg for &DeviceSliceMut<'a, T> {
1281    #[inline]
1282    fn as_kernel_arg_ptr(&self) -> *mut c_void {
1283        &self.ptr as *const CUdeviceptr as *mut c_void
1284    }
1285}
1286
1287unsafe impl<'a, T: DeviceRepr> KernelArg for &mut DeviceSliceMut<'a, T> {
1288    #[inline]
1289    fn as_kernel_arg_ptr(&self) -> *mut c_void {
1290        &self.ptr as *const CUdeviceptr as *mut c_void
1291    }
1292}
1293
#[cfg(test)]
mod slice_tests {
    //! Host-only: verify slice / slice_mut bounds math without a GPU.
    use super::*;

    fn fake_slice<T: DeviceRepr>(ptr: u64, len: usize) -> DeviceSlice<'static, T> {
        DeviceSlice {
            ptr: CUdeviceptr(ptr),
            len,
            _marker: PhantomData,
        }
    }

    fn fake_slice_mut<T: DeviceRepr>(ptr: u64, len: usize) -> DeviceSliceMut<'static, T> {
        DeviceSliceMut {
            ptr: CUdeviceptr(ptr),
            len,
            _marker: PhantomData,
        }
    }

    #[test]
    fn slice_offsets_ptr_by_element_bytes() {
        let s: DeviceSlice<'_, f32> = fake_slice(0x1000, 16);
        let sub = s.slice(4..12);
        assert_eq!(sub.len(), 8);
        assert_eq!(sub.as_raw().0, 0x1000 + 4 * 4); // 4 elements * 4 bytes/f32
    }

    #[test]
    fn slice_of_slice_stays_correct() {
        let s: DeviceSlice<'_, f64> = fake_slice(0x2000, 100);
        let mid = s.slice(10..90);
        let inner = mid.slice(5..15);
        assert_eq!(inner.len(), 10);
        // 10 + 5 = 15 from start, each f64 = 8 bytes.
        assert_eq!(inner.as_raw().0, 0x2000 + 15 * 8);
    }

    #[test]
    #[should_panic(expected = "out of bounds")]
    fn slice_end_past_len_panics() {
        let s: DeviceSlice<'_, u8> = fake_slice(0, 10);
        let _ = s.slice(0..11);
    }

    #[test]
    #[should_panic(expected = "out of bounds")]
    #[allow(clippy::reversed_empty_ranges)]
    // Intentionally inverted: this test verifies that `slice` rejects
    // start > end with the "out of bounds" panic message.
    fn slice_inverted_range_panics() {
        let s: DeviceSlice<'_, u8> = fake_slice(0, 10);
        let _ = s.slice(5..3);
    }

    // ---- DeviceSliceMut::slice / slice_mut ---------------------------------

    #[test]
    fn slice_mut_offsets_ptr_by_element_bytes() {
        let mut s: DeviceSliceMut<'_, u32> = fake_slice_mut(0x3000, 32);
        let sub = s.slice_mut(8..24);
        assert_eq!(sub.len(), 16);
        assert_eq!(sub.as_raw().0, 0x3000 + 8 * 4); // 8 elements * 4 bytes/u32
    }

    #[test]
    fn shared_slice_of_slice_mut_offsets_ptr() {
        let s: DeviceSliceMut<'_, u16> = fake_slice_mut(0x3000, 10);
        let sub = s.slice(2..5);
        assert_eq!(sub.len(), 3);
        assert_eq!(sub.as_raw().0, 0x3000 + 2 * 2); // 2 elements * 2 bytes/u16
    }

    #[test]
    #[should_panic(expected = "out of bounds")]
    fn slice_mut_end_past_len_panics() {
        let mut s: DeviceSliceMut<'_, u8> = fake_slice_mut(0, 4);
        let _ = s.slice_mut(0..5);
    }

    #[test]
    #[should_panic(expected = "out of bounds")]
    #[allow(clippy::reversed_empty_ranges)]
    // Intentionally inverted: `slice_mut` must reject start > end too.
    fn slice_mut_inverted_range_panics() {
        let mut s: DeviceSliceMut<'_, u8> = fake_slice_mut(0, 10);
        let _ = s.slice_mut(6..2);
    }

    // ---- from_raw_parts -----------------------------------------------------

    #[test]
    fn from_raw_parts_preserves_ptr_and_len() {
        // SAFETY: host-only inspection, no dereference. The pointer is
        // never read or written.
        let s: DeviceSlice<'static, f32> =
            unsafe { DeviceSlice::from_raw_parts(CUdeviceptr(0x4000), 32) };
        assert_eq!(s.as_raw().0, 0x4000);
        assert_eq!(s.len(), 32);
    }

    #[test]
    fn from_raw_parts_mut_preserves_ptr_and_len() {
        // SAFETY: host-only inspection, no dereference. The pointer is
        // never read or written.
        let s: DeviceSliceMut<'static, u32> =
            unsafe { DeviceSliceMut::from_raw_parts(CUdeviceptr(0x8000), 64) };
        assert_eq!(s.as_raw().0, 0x8000);
        assert_eq!(s.len(), 64);
    }
}
1362
#[cfg(test)]
mod kernel_arg_tests {
    //! Host-only: verify the returned pointer points to the CUdeviceptr
    //! actually stored inside the buffer/slice, so kernels see the right
    //! device address. We don't need a GPU to check this — we fabricate
    //! DeviceSlice / DeviceSliceMut values with PhantomData and inspect
    //! their bytes.

    use super::*;
    use core::mem::size_of;

    #[test]
    fn slice_kernel_arg_points_at_ptr_field() {
        let slice: DeviceSlice<'_, f32> = DeviceSlice {
            ptr: CUdeviceptr(0xDEAD_BEEF_u64),
            len: 42,
            _marker: PhantomData,
        };
        let kernel_arg = (&slice).as_kernel_arg_ptr();
        // The returned pointer should point to a u64 = 0xDEADBEEF.
        unsafe {
            let as_u64 = *(kernel_arg as *const u64);
            assert_eq!(as_u64, 0xDEAD_BEEF);
        }
        // And the pointer must live inside the slice struct itself.
        let slice_start = &slice as *const _ as usize;
        let slice_end = slice_start + size_of::<DeviceSlice<'_, f32>>();
        let arg_addr = kernel_arg as usize;
        assert!((slice_start..slice_end).contains(&arg_addr));
    }

    #[test]
    fn slice_mut_kernel_arg_points_at_ptr_field() {
        let mut slice: DeviceSliceMut<'_, u32> = DeviceSliceMut {
            ptr: CUdeviceptr(0x1234_5678_9ABC_u64),
            len: 7,
            _marker: PhantomData,
        };
        // Shared-reference impl: read the slot before taking `&mut` below.
        let shared = (&slice).as_kernel_arg_ptr();
        assert_eq!(unsafe { *(shared as *const u64) }, 0x1234_5678_9ABC);
        // Exclusive-reference impl must expose the same address.
        let exclusive = (&mut slice).as_kernel_arg_ptr();
        assert_eq!(unsafe { *(exclusive as *const u64) }, 0x1234_5678_9ABC);
    }
}
1388        let slice_end = slice_start + size_of::<DeviceSlice<'_, f32>>();
1389        let arg_addr = kernel_arg as usize;
1390        assert!((slice_start..slice_end).contains(&arg_addr));
1391    }
1392}