// cuda_rust_wasm/memory/host_memory.rs
//! Host (CPU) memory management

use crate::{Result, runtime_error};
use std::alloc::{alloc, alloc_zeroed, dealloc, Layout};
use std::marker::PhantomData;
use std::ptr::NonNull;

/// An owned buffer of `len` elements of `T` in host (CPU) memory.
///
/// The storage is obtained from Rust's global allocator (see `HostBuffer::new`)
/// and handed back to it when the buffer is dropped. It is ordinary pageable
/// memory: nothing in this type pins or page-locks it.
pub struct HostBuffer<T> {
    /// Start of the allocation; never null.
    ptr: NonNull<T>,
    /// Number of `T` elements in the allocation.
    len: usize,
    /// Layout the allocation was made with; it must be passed back unchanged
    /// when the memory is deallocated.
    layout: Layout,
    /// Marks that the buffer logically owns values of type `T`.
    phantom: PhantomData<T>,
}

16impl<T: Copy> HostBuffer<T> {
17    /// Allocate a new pinned host buffer
18    pub fn new(len: usize) -> Result<Self> {
19        if len == 0 {
20            return Err(runtime_error!("Cannot allocate zero-length buffer"));
21        }
22        
23        let size = len * std::mem::size_of::<T>();
24        let align = std::mem::align_of::<T>();
25        
26        let layout = Layout::from_size_align(size, align)
27            .map_err(|e| runtime_error!("Invalid layout: {}", e))?;
28        
29        unsafe {
30            let raw_ptr = alloc(layout);
31            if raw_ptr.is_null() {
32                return Err(runtime_error!(
33                    "Failed to allocate {} bytes of host memory",
34                    size
35                ));
36            }
37            
38            let ptr = NonNull::new_unchecked(raw_ptr as *mut T);
39            
40            Ok(Self {
41                ptr,
42                len,
43                layout,
44                phantom: PhantomData,
45            })
46        }
47    }
48    
49    /// Get buffer length
50    pub fn len(&self) -> usize {
51        self.len
52    }
53    
54    /// Check if buffer is empty
55    pub fn is_empty(&self) -> bool {
56        self.len == 0
57    }
58    
59    /// Get a slice view of the buffer
60    pub fn as_slice(&self) -> &[T] {
61        unsafe {
62            std::slice::from_raw_parts(self.ptr.as_ptr(), self.len)
63        }
64    }
65    
66    /// Get a mutable slice view of the buffer
67    pub fn as_mut_slice(&mut self) -> &mut [T] {
68        unsafe {
69            std::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len)
70        }
71    }
72    
73    /// Copy from a slice
74    pub fn copy_from_slice(&mut self, src: &[T]) -> Result<()> {
75        if src.len() != self.len {
76            return Err(runtime_error!(
77                "Source length {} doesn't match buffer length {}",
78                src.len(),
79                self.len
80            ));
81        }
82        
83        self.as_mut_slice().copy_from_slice(src);
84        Ok(())
85    }
86    
87    /// Copy to a slice
88    pub fn copy_to_slice(&self, dst: &mut [T]) -> Result<()> {
89        if dst.len() != self.len {
90            return Err(runtime_error!(
91                "Destination length {} doesn't match buffer length {}",
92                dst.len(),
93                self.len
94            ));
95        }
96        
97        dst.copy_from_slice(self.as_slice());
98        Ok(())
99    }
100    
101    /// Fill buffer with a value
102    pub fn fill(&mut self, value: T) {
103        for elem in self.as_mut_slice() {
104            *elem = value;
105        }
106    }
107}
108
109impl<T> Drop for HostBuffer<T> {
110    fn drop(&mut self) {
111        unsafe {
112            dealloc(self.ptr.as_ptr() as *mut u8, self.layout);
113        }
114    }
115}
116
117// Implement Index traits for convenient access
118impl<T: Copy> std::ops::Index<usize> for HostBuffer<T> {
119    type Output = T;
120    
121    fn index(&self, index: usize) -> &Self::Output {
122        &self.as_slice()[index]
123    }
124}
125
126impl<T: Copy> std::ops::IndexMut<usize> for HostBuffer<T> {
127    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
128        &mut self.as_mut_slice()[index]
129    }
130}
131
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_host_buffer_allocation() {
        let buffer = HostBuffer::<f32>::new(1024).unwrap();
        assert_eq!(buffer.len(), 1024);
        assert!(!buffer.is_empty());
    }

    #[test]
    fn test_host_buffer_zero_length_rejected() {
        assert!(HostBuffer::<f32>::new(0).is_err());
    }

    #[test]
    fn test_host_buffer_copy() {
        let mut buffer = HostBuffer::<i32>::new(10).unwrap();
        let data: Vec<i32> = (0..10).collect();

        buffer.copy_from_slice(&data).unwrap();

        let mut result = vec![0; 10];
        buffer.copy_to_slice(&mut result).unwrap();

        assert_eq!(data, result);
    }

    #[test]
    fn test_host_buffer_copy_length_mismatch() {
        let mut buffer = HostBuffer::<i32>::new(4).unwrap();
        assert!(buffer.copy_from_slice(&[1, 2, 3]).is_err());

        // The length check happens before any element is read.
        let mut too_long = vec![0; 5];
        assert!(buffer.copy_to_slice(&mut too_long).is_err());
    }

    #[test]
    fn test_host_buffer_fill() {
        let mut buffer = HostBuffer::<f64>::new(100).unwrap();
        // 2.5 is exactly representable, so `==` is exact. (`3.14` would trip
        // clippy's deny-by-default `approx_constant` lint.)
        buffer.fill(2.5);

        assert!(buffer.as_slice().iter().all(|&x| x == 2.5));
    }

    #[test]
    fn test_host_buffer_index() {
        let mut buffer = HostBuffer::<u32>::new(4).unwrap();
        buffer.fill(0);
        buffer[2] = 7;

        assert_eq!(buffer[2], 7);
        assert_eq!(buffer.as_slice(), &[0, 0, 7, 0]);
    }

    #[test]
    #[should_panic]
    fn test_host_buffer_index_out_of_bounds() {
        let mut buffer = HostBuffer::<u32>::new(4).unwrap();
        buffer.fill(0);
        let _value = buffer[4];
    }
}