oxirs_vec/gpu/buffer.rs

//! GPU memory buffer management

use anyhow::{anyhow, Result};

/// GPU memory buffer for vector data
#[derive(Debug)]
pub struct GpuBuffer {
    /// Raw device (or host fallback) pointer to the element storage.
    ptr: *mut f32,
    /// Capacity in `f32` elements (not bytes).
    size: usize,
    /// Device ordinal this buffer was allocated on (unused by the host fallback).
    device_id: i32,
}

// SAFETY: `GpuBuffer` exclusively owns the allocation behind `ptr` and frees it
// in `Drop`; all mutation goes through `&mut self`, so the handle can be moved
// to and shared across threads.
unsafe impl Send for GpuBuffer {}
unsafe impl Sync for GpuBuffer {}

impl GpuBuffer {
    /// Allocate a buffer holding `size` f32 elements on the given device.
    pub fn new(size: usize, device_id: i32) -> Result<Self> {
        let ptr = Self::allocate_gpu_memory(size * std::mem::size_of::<f32>(), device_id)?;
        Ok(Self {
            ptr: ptr as *mut f32,
            size,
            device_id,
        })
    }

    /// Copy `data` from host memory into the buffer.
    pub fn copy_from_host(&mut self, data: &[f32]) -> Result<()> {
        if data.len() > self.size {
            return Err(anyhow!("Data size exceeds buffer capacity"));
        }
        self.copy_host_to_device(data.as_ptr(), self.ptr, std::mem::size_of_val(data))
    }

    /// Copy `data.len()` elements from the buffer into host memory.
    pub fn copy_to_host(&self, data: &mut [f32]) -> Result<()> {
        if data.len() > self.size {
            return Err(anyhow!("Host slice length exceeds buffer capacity"));
        }
        self.copy_device_to_host(self.ptr, data.as_mut_ptr(), std::mem::size_of_val(data))
    }

    /// Allocate `size` bytes on the GPU when the `cuda` feature is enabled;
    /// otherwise allocate host memory so the buffer can be exercised without a GPU.
    #[allow(unused_variables)]
    fn allocate_gpu_memory(size: usize, device_id: i32) -> Result<*mut u8> {
        #[cfg(feature = "cuda")]
        {
            use cuda_runtime_sys::*;
            let mut ptr: *mut std::ffi::c_void = std::ptr::null_mut();
            unsafe {
                let result = cudaSetDevice(device_id);
                if result != cudaError_t::cudaSuccess {
                    return Err(anyhow!("Failed to set CUDA device {}", device_id));
                }

                let result = cudaMalloc(&mut ptr, size);
                if result != cudaError_t::cudaSuccess {
                    return Err(anyhow!("Failed to allocate {} bytes of GPU memory", size));
                }
            }
            Ok(ptr as *mut u8)
        }

        #[cfg(not(feature = "cuda"))]
        {
            // Fallback: allocate host memory for testing
            let layout = std::alloc::Layout::from_size_align(size, std::mem::align_of::<f32>())
                .map_err(|e| anyhow!("Invalid memory layout: {}", e))?;
            unsafe {
                let ptr = std::alloc::alloc(layout);
                if ptr.is_null() {
                    return Err(anyhow!("Failed to allocate memory"));
                }
                Ok(ptr)
            }
        }
    }

    /// Copy `size` bytes from host pointer `src` to device pointer `dst`.
    fn copy_host_to_device(&self, src: *const f32, dst: *mut f32, size: usize) -> Result<()> {
        #[cfg(feature = "cuda")]
        {
            use cuda_runtime_sys::*;
            unsafe {
                let result = cudaMemcpy(
                    dst as *mut std::ffi::c_void,
                    src as *const std::ffi::c_void,
                    size,
                    cudaMemcpyKind::cudaMemcpyHostToDevice,
                );
                if result != cudaError_t::cudaSuccess {
                    return Err(anyhow!("Failed to copy data to device"));
                }
            }
        }

        #[cfg(not(feature = "cuda"))]
        {
            // Fallback: plain memory copy for testing; `size` is in bytes,
            // so convert it to an element count for `copy_nonoverlapping`.
            unsafe {
                std::ptr::copy_nonoverlapping(src, dst, size / std::mem::size_of::<f32>());
            }
        }
        Ok(())
    }

    /// Copy `size` bytes from device pointer `src` to host pointer `dst`.
    fn copy_device_to_host(&self, src: *const f32, dst: *mut f32, size: usize) -> Result<()> {
        #[cfg(feature = "cuda")]
        {
            use cuda_runtime_sys::*;
            unsafe {
                let result = cudaMemcpy(
                    dst as *mut std::ffi::c_void,
                    src as *const std::ffi::c_void,
                    size,
                    cudaMemcpyKind::cudaMemcpyDeviceToHost,
                );
                if result != cudaError_t::cudaSuccess {
                    return Err(anyhow!("Failed to copy data from device"));
                }
            }
        }

        #[cfg(not(feature = "cuda"))]
        {
            // Fallback: plain memory copy for testing; `size` is in bytes,
            // so convert it to an element count for `copy_nonoverlapping`.
            unsafe {
                std::ptr::copy_nonoverlapping(src, dst, size / std::mem::size_of::<f32>());
            }
        }
        Ok(())
    }

    /// Raw device pointer to the buffer's storage.
    pub fn ptr(&self) -> *mut f32 {
        self.ptr
    }

    /// Buffer capacity in `f32` elements (not bytes).
    pub fn size(&self) -> usize {
        self.size
    }

    /// Device the buffer was allocated on.
    pub fn device_id(&self) -> i32 {
        self.device_id
    }

    /// Whether the buffer holds a live allocation.
    pub fn is_valid(&self) -> bool {
        !self.ptr.is_null()
    }

    /// Zero out the buffer
    pub fn zero(&mut self) -> Result<()> {
        #[cfg(feature = "cuda")]
        {
            use cuda_runtime_sys::*;
            unsafe {
                let result = cudaMemset(
                    self.ptr as *mut std::ffi::c_void,
                    0,
                    self.size * std::mem::size_of::<f32>(),
                );
                if result != cudaError_t::cudaSuccess {
                    return Err(anyhow!("Failed to zero buffer"));
                }
            }
        }

        #[cfg(not(feature = "cuda"))]
        {
            unsafe {
                std::ptr::write_bytes(self.ptr, 0, self.size);
            }
        }
        Ok(())
    }
}

impl Drop for GpuBuffer {
    fn drop(&mut self) {
        if !self.ptr.is_null() {
            #[cfg(feature = "cuda")]
            {
                use cuda_runtime_sys::*;
                unsafe {
                    let _ = cudaFree(self.ptr as *mut std::ffi::c_void);
                }
            }

            #[cfg(not(feature = "cuda"))]
            {
                // Fallback: deallocate host memory
                let layout = std::alloc::Layout::from_size_align(
                    self.size * std::mem::size_of::<f32>(),
                    std::mem::align_of::<f32>(),
                );
                if let Ok(layout) = layout {
                    unsafe {
                        std::alloc::dealloc(self.ptr as *mut u8, layout);
                    }
                }
            }
        }
    }
}
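
// Usage sketch: a minimal round-trip test for the host-memory fallback path
// (built with the `cuda` feature disabled). It relies only on the public API
// defined above; the test name and the sample values are illustrative, not
// part of the original module.
#[cfg(all(test, not(feature = "cuda")))]
mod tests {
    use super::*;

    #[test]
    fn round_trip_and_zero() {
        let mut buffer = GpuBuffer::new(4, 0).expect("host fallback allocation should succeed");
        assert!(buffer.is_valid());
        assert_eq!(buffer.size(), 4);

        // Upload four values and read them back.
        let input = [1.0_f32, 2.0, 3.0, 4.0];
        buffer.copy_from_host(&input).unwrap();
        let mut output = [0.0_f32; 4];
        buffer.copy_to_host(&mut output).unwrap();
        assert_eq!(input, output);

        // Zeroing clears every element.
        buffer.zero().unwrap();
        buffer.copy_to_host(&mut output).unwrap();
        assert_eq!(output, [0.0_f32; 4]);
    }
}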