Skip to main content

nexar/memory/
buffer.rs

1//! Typed buffer wrappers that encode memory space in the type system.
2//!
3//! These are zero-cost wrappers around raw `u64` pointers. The type parameter
4//! prevents accidentally passing a host pointer where a device pointer is
5//! expected (and vice versa).
6//!
7//! The raw `u64` API remains for backward compatibility and FFI use cases.
8//! These wrappers are opt-in for users who want compile-time memory space safety.
9
10use std::marker::PhantomData;
11
12// ── Sealed trait pattern ─────────────────────────────────────────────
13
mod private {
    /// Supertrait that cannot be named outside this module, so downstream
    /// crates cannot implement `MemorySpace` for their own types.
    pub trait Sealed {}
}

/// Marker trait for memory spaces (host vs device).
///
/// The trait is sealed: [`Host`] and [`Device`] are its only implementors.
///
/// The supertraits mirror the traits that generic wrappers in this module
/// derive. `#[derive(Clone, Copy, ...)]` on a type like `Wrapper<S>` emits impls
/// bounded on `S: Clone`, `S: Copy`, etc. The markers must therefore implement
/// those traits for the wrappers to be `Copy`/`Eq`/`Hash`/`Debug` at all.
/// Making them supertraits also lets generic code bounded only by
/// `S: MemorySpace` copy, compare, hash and debug-print such wrappers.
pub trait MemorySpace: private::Sealed + Copy + Eq + std::hash::Hash + std::fmt::Debug {}

/// Host (CPU) memory.
///
/// Uninhabited: used purely as a type-level tag, never as a value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Host {}
impl private::Sealed for Host {}
impl MemorySpace for Host {}

/// Device (GPU) memory.
///
/// Uninhabited: used purely as a type-level tag, never as a value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Device {}
impl private::Sealed for Device {}
impl MemorySpace for Device {}
30
31// ── BufferPtr ────────────────────────────────────────────────────────
32
33/// A typed pointer to memory in a specific memory space.
34///
35/// Zero-cost wrapper around a raw `u64` pointer. The type parameter `S`
36/// prevents accidentally passing a host pointer where a device pointer
37/// is expected.
38#[repr(transparent)]
39#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
40pub struct BufferPtr<S: MemorySpace> {
41    ptr: u64,
42    _space: PhantomData<S>,
43}
44
45impl<S: MemorySpace> BufferPtr<S> {
46    /// Wrap a raw `u64` pointer.
47    ///
48    /// # Safety
49    /// The pointer must actually point to memory in the space `S`.
50    pub unsafe fn new(ptr: u64) -> Self {
51        Self {
52            ptr,
53            _space: PhantomData,
54        }
55    }
56
57    /// Get the raw `u64` pointer.
58    pub fn as_u64(&self) -> u64 {
59        self.ptr
60    }
61}
62
63impl<S: MemorySpace> std::fmt::Display for BufferPtr<S> {
64    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
65        write!(f, "BufferPtr(0x{:x})", self.ptr)
66    }
67}
68
69// ── BufferRef ────────────────────────────────────────────────────────
70
71/// A typed, sized buffer reference in a specific memory space.
72///
73/// Pairs a [`BufferPtr`] with a byte length, providing both type-level
74/// memory space safety and runtime size information.
75#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
76pub struct BufferRef<S: MemorySpace> {
77    ptr: BufferPtr<S>,
78    len_bytes: usize,
79}
80
81impl<S: MemorySpace> BufferRef<S> {
82    /// Create a new buffer reference.
83    ///
84    /// # Safety
85    /// `ptr` must point to at least `len_bytes` of valid memory in space `S`.
86    pub unsafe fn new(ptr: u64, len_bytes: usize) -> Self {
87        Self {
88            ptr: unsafe { BufferPtr::new(ptr) },
89            len_bytes,
90        }
91    }
92
93    /// Get a reference to the typed pointer.
94    pub fn ptr(&self) -> &BufferPtr<S> {
95        &self.ptr
96    }
97
98    /// Size of the buffer in bytes.
99    pub fn len_bytes(&self) -> usize {
100        self.len_bytes
101    }
102
103    /// Returns true if the buffer has zero length.
104    pub fn is_empty(&self) -> bool {
105        self.len_bytes == 0
106    }
107
108    /// Get the raw `u64` pointer.
109    pub fn as_u64(&self) -> u64 {
110        self.ptr.as_u64()
111    }
112}
113
114impl<S: MemorySpace> std::fmt::Display for BufferRef<S> {
115    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
116        write!(
117            f,
118            "BufferRef(0x{:x}, {}B)",
119            self.ptr.as_u64(),
120            self.len_bytes
121        )
122    }
123}
124
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_buffer_ptr_host() {
        let samples: Vec<f32> = vec![1.0, 2.0, 3.0];
        let addr = samples.as_ptr() as u64;
        // SAFETY: `addr` is the start of `samples`, which is live host memory.
        let typed = unsafe { BufferPtr::<Host>::new(addr) };
        assert_eq!(typed.as_u64(), addr);
    }

    #[test]
    fn test_buffer_ref_size() {
        let bytes: Vec<u8> = vec![0; 1024];
        // SAFETY: `bytes` owns exactly `bytes.len()` (1024) live host bytes.
        let view = unsafe { BufferRef::<Host>::new(bytes.as_ptr() as u64, bytes.len()) };
        assert_eq!(view.len_bytes(), 1024);
        assert!(!view.is_empty());
    }

    #[test]
    fn test_buffer_ref_empty() {
        // SAFETY: zero length, and nothing is read through this address here.
        let view = unsafe { BufferRef::<Host>::new(0, 0) };
        assert!(view.is_empty());
    }

    #[test]
    fn test_display() {
        // SAFETY: the address is only formatted, never dereferenced.
        let device_ptr = unsafe { BufferPtr::<Device>::new(0xDEAD) };
        assert!(device_ptr.to_string().contains("0xdead"));

        // SAFETY: the address is only formatted, never dereferenced.
        let host_buf = unsafe { BufferRef::<Host>::new(0xFF, 256) };
        let rendered = host_buf.to_string();
        for needle in ["0xff", "256B"] {
            assert!(rendered.contains(needle), "{rendered:?} should contain {needle:?}");
        }
    }

    #[test]
    fn test_type_safety_compiles() {
        // Host and Device are distinct type parameters: each helper accepts only
        // its own memory space, so swapping the arguments below would not compile.
        fn _takes_host(_buf: &BufferRef<Host>) {}
        fn _takes_device(_buf: &BufferRef<Device>) {}

        // SAFETY: neither address is ever dereferenced.
        let (host_buf, device_buf) = unsafe {
            (
                BufferRef::<Host>::new(0x1000, 64),
                BufferRef::<Device>::new(0x2000, 64),
            )
        };
        _takes_host(&host_buf);
        _takes_device(&device_buf);
    }
}