1use std::marker::PhantomData;
11
/// Private home of the supertrait that seals `MemorySpace`.
///
/// The module is not `pub`, so code outside the module that declares it cannot name
/// `private::Sealed` and therefore cannot implement it for its own types.
mod private {
    /// Supertrait required by `MemorySpace`.
    ///
    /// It is `pub` only so that it may appear in the bounds of a public trait; the
    /// enclosing private module is what keeps it out of reach of downstream code.
    pub trait Sealed {}
}
17
18pub trait MemorySpace: private::Sealed {}
20
21pub enum Host {}
23impl private::Sealed for Host {}
24impl MemorySpace for Host {}
25
26pub enum Device {}
28impl private::Sealed for Device {}
29impl MemorySpace for Device {}
30
31#[repr(transparent)]
39#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
40pub struct BufferPtr<S: MemorySpace> {
41 ptr: u64,
42 _space: PhantomData<S>,
43}
44
45impl<S: MemorySpace> BufferPtr<S> {
46 pub unsafe fn new(ptr: u64) -> Self {
51 Self {
52 ptr,
53 _space: PhantomData,
54 }
55 }
56
57 pub fn as_u64(&self) -> u64 {
59 self.ptr
60 }
61}
62
63impl<S: MemorySpace> std::fmt::Display for BufferPtr<S> {
64 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
65 write!(f, "BufferPtr(0x{:x})", self.ptr)
66 }
67}
68
69#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
76pub struct BufferRef<S: MemorySpace> {
77 ptr: BufferPtr<S>,
78 len_bytes: usize,
79}
80
81impl<S: MemorySpace> BufferRef<S> {
82 pub unsafe fn new(ptr: u64, len_bytes: usize) -> Self {
87 Self {
88 ptr: unsafe { BufferPtr::new(ptr) },
89 len_bytes,
90 }
91 }
92
93 pub fn ptr(&self) -> &BufferPtr<S> {
95 &self.ptr
96 }
97
98 pub fn len_bytes(&self) -> usize {
100 self.len_bytes
101 }
102
103 pub fn is_empty(&self) -> bool {
105 self.len_bytes == 0
106 }
107
108 pub fn as_u64(&self) -> u64 {
110 self.ptr.as_u64()
111 }
112}
113
114impl<S: MemorySpace> std::fmt::Display for BufferRef<S> {
115 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
116 write!(
117 f,
118 "BufferRef(0x{:x}, {}B)",
119 self.ptr.as_u64(),
120 self.len_bytes
121 )
122 }
123}
124
#[cfg(test)]
mod tests {
    use super::*;

    /// A host pointer built from a live allocation reports that allocation's address.
    #[test]
    fn test_buffer_ptr_host() {
        let samples: Vec<f32> = vec![1.0, 2.0, 3.0];
        let addr = samples.as_ptr() as u64;
        // SAFETY: the handle is only read back through `as_u64`, never dereferenced.
        let host_ptr = unsafe { BufferPtr::<Host>::new(addr) };
        assert_eq!(host_ptr.as_u64(), addr);
    }

    /// A reference over a non-zero length reports that length and is not empty.
    #[test]
    fn test_buffer_ref_size() {
        let bytes: Vec<u8> = vec![0; 1024];
        let addr = bytes.as_ptr() as u64;
        // SAFETY: only the metadata accessors are called; the memory is never touched.
        let host_ref = unsafe { BufferRef::<Host>::new(addr, 1024) };
        assert_eq!(host_ref.len_bytes(), 1024);
        assert!(!host_ref.is_empty());
    }

    /// A zero-length reference is reported as empty.
    #[test]
    fn test_buffer_ref_empty() {
        // SAFETY: only `is_empty` is called; the null address is never dereferenced.
        let nothing = unsafe { BufferRef::<Host>::new(0, 0) };
        assert!(nothing.is_empty());
    }

    /// `Display` output contains the lowercase hex address and, for refs, the byte length.
    #[test]
    fn test_display() {
        // SAFETY: the made-up address is only formatted, never dereferenced.
        let device_ptr = unsafe { BufferPtr::<Device>::new(0xDEAD) };
        let rendered_ptr = device_ptr.to_string();
        assert!(rendered_ptr.contains("0xdead"));

        // SAFETY: the made-up address is only formatted, never dereferenced.
        let host_ref = unsafe { BufferRef::<Host>::new(0xFF, 256) };
        let rendered_ref = host_ref.to_string();
        for needle in ["0xff", "256B"] {
            assert!(rendered_ref.contains(needle));
        }
    }

    /// Compile-time check: each helper accepts only references tagged with its own
    /// memory space, and correctly tagged values are accepted.
    #[test]
    fn test_type_safety_compiles() {
        fn _takes_host(_buf: &BufferRef<Host>) {}
        fn _takes_device(_buf: &BufferRef<Device>) {}

        // SAFETY: the made-up addresses are only passed around, never dereferenced.
        let (on_host, on_device) = unsafe {
            (
                BufferRef::<Host>::new(0x1000, 64),
                BufferRef::<Device>::new(0x2000, 64),
            )
        };
        _takes_host(&on_host);
        _takes_device(&on_device);
    }
}