cuda_rust_wasm/memory/
host_memory.rs1use crate::{Result, runtime_error};
4use std::alloc::{alloc, dealloc, Layout};
5use std::marker::PhantomData;
6use std::ptr::NonNull;
7
/// An owned allocation of `len` elements of `T` in host memory.
///
/// The memory is obtained from the global allocator using `layout`, and the
/// same `layout` is handed back to the allocator when the buffer is dropped.
pub struct HostBuffer<T> {
    /// Number of `T` elements the allocation holds.
    len: usize,
    /// Start of the allocation; never null.
    ptr: NonNull<T>,
    /// Layout the allocation was made with, kept so deallocation can reuse it exactly.
    layout: Layout,
    /// Records that the buffer logically owns values of type `T`.
    phantom: PhantomData<T>,
}
15
16impl<T: Copy> HostBuffer<T> {
17 pub fn new(len: usize) -> Result<Self> {
19 if len == 0 {
20 return Err(runtime_error!("Cannot allocate zero-length buffer"));
21 }
22
23 let size = len * std::mem::size_of::<T>();
24 let align = std::mem::align_of::<T>();
25
26 let layout = Layout::from_size_align(size, align)
27 .map_err(|e| runtime_error!("Invalid layout: {}", e))?;
28
29 unsafe {
30 let raw_ptr = alloc(layout);
31 if raw_ptr.is_null() {
32 return Err(runtime_error!(
33 "Failed to allocate {} bytes of host memory",
34 size
35 ));
36 }
37
38 let ptr = NonNull::new_unchecked(raw_ptr as *mut T);
39
40 Ok(Self {
41 ptr,
42 len,
43 layout,
44 phantom: PhantomData,
45 })
46 }
47 }
48
49 pub fn len(&self) -> usize {
51 self.len
52 }
53
54 pub fn is_empty(&self) -> bool {
56 self.len == 0
57 }
58
59 pub fn as_slice(&self) -> &[T] {
61 unsafe {
62 std::slice::from_raw_parts(self.ptr.as_ptr(), self.len)
63 }
64 }
65
66 pub fn as_mut_slice(&mut self) -> &mut [T] {
68 unsafe {
69 std::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len)
70 }
71 }
72
73 pub fn copy_from_slice(&mut self, src: &[T]) -> Result<()> {
75 if src.len() != self.len {
76 return Err(runtime_error!(
77 "Source length {} doesn't match buffer length {}",
78 src.len(),
79 self.len
80 ));
81 }
82
83 self.as_mut_slice().copy_from_slice(src);
84 Ok(())
85 }
86
87 pub fn copy_to_slice(&self, dst: &mut [T]) -> Result<()> {
89 if dst.len() != self.len {
90 return Err(runtime_error!(
91 "Destination length {} doesn't match buffer length {}",
92 dst.len(),
93 self.len
94 ));
95 }
96
97 dst.copy_from_slice(self.as_slice());
98 Ok(())
99 }
100
101 pub fn fill(&mut self, value: T) {
103 for elem in self.as_mut_slice() {
104 *elem = value;
105 }
106 }
107}
108
109impl<T> Drop for HostBuffer<T> {
110 fn drop(&mut self) {
111 unsafe {
112 dealloc(self.ptr.as_ptr() as *mut u8, self.layout);
113 }
114 }
115}
116
117impl<T: Copy> std::ops::Index<usize> for HostBuffer<T> {
119 type Output = T;
120
121 fn index(&self, index: usize) -> &Self::Output {
122 &self.as_slice()[index]
123 }
124}
125
126impl<T: Copy> std::ops::IndexMut<usize> for HostBuffer<T> {
127 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
128 &mut self.as_mut_slice()[index]
129 }
130}
131
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_host_buffer_allocation() {
        let buffer = HostBuffer::<f32>::new(1024).unwrap();
        assert_eq!(buffer.len(), 1024);
        assert!(!buffer.is_empty());
    }

    #[test]
    fn test_host_buffer_rejects_zero_length() {
        assert!(HostBuffer::<f32>::new(0).is_err());
    }

    #[test]
    fn test_host_buffer_copy() {
        let mut buffer = HostBuffer::<i32>::new(10).unwrap();
        let data: Vec<i32> = (0..10).collect();

        buffer.copy_from_slice(&data).unwrap();

        let mut result = vec![0; 10];
        buffer.copy_to_slice(&mut result).unwrap();

        assert_eq!(data, result);
    }

    #[test]
    fn test_host_buffer_copy_length_mismatch() {
        let mut buffer = HostBuffer::<i32>::new(4).unwrap();
        assert!(buffer.copy_from_slice(&[1, 2, 3]).is_err());

        let mut too_long = [0i32; 5];
        assert!(buffer.copy_to_slice(&mut too_long).is_err());
    }

    #[test]
    fn test_host_buffer_fill() {
        let mut buffer = HostBuffer::<f64>::new(100).unwrap();
        // 2.5 instead of 3.14: clippy's deny-by-default `approx_constant` lint
        // rejects literals that approximate PI.
        buffer.fill(2.5);

        assert_eq!(buffer[0], 2.5);
        assert_eq!(buffer[99], 2.5);
        assert!(buffer.as_slice().iter().all(|&x| x == 2.5));
    }

    #[test]
    fn test_host_buffer_index_mut() {
        let mut buffer = HostBuffer::<u8>::new(3).unwrap();
        buffer.fill(0);
        buffer[1] = 7;
        assert_eq!(buffer.as_slice(), &[0, 7, 0]);
    }

    #[test]
    #[should_panic]
    fn test_host_buffer_index_out_of_bounds() {
        let mut buffer = HostBuffer::<u8>::new(2).unwrap();
        buffer.fill(0);
        let value = buffer[2];
        assert_eq!(value, 0);
    }
}