1use anyhow::{anyhow, Result};
4
/// Owning handle to a contiguous buffer of `f32` elements.
///
/// The storage is obtained through `cudaMalloc` when the `cuda` feature is
/// enabled and through the global allocator otherwise. The value only records
/// the raw pointer, the element capacity and the device index.
#[derive(Debug)]
pub struct GpuBuffer {
    /// Start of the allocation, typed as `f32` elements.
    ptr: *mut f32,
    /// Capacity in `f32` elements (not bytes).
    size: usize,
    /// Device index supplied at construction; handed to `cudaSetDevice` when
    /// the `cuda` feature is enabled.
    device_id: i32,
}
12
13unsafe impl Send for GpuBuffer {}
14unsafe impl Sync for GpuBuffer {}
15
16impl GpuBuffer {
17 pub fn new(size: usize, device_id: i32) -> Result<Self> {
18 let ptr = Self::allocate_gpu_memory(size * std::mem::size_of::<f32>(), device_id)?;
19 Ok(Self {
20 ptr: ptr as *mut f32,
21 size,
22 device_id,
23 })
24 }
25
26 pub fn copy_from_host(&mut self, data: &[f32]) -> Result<()> {
27 if data.len() > self.size {
28 return Err(anyhow!("Data size exceeds buffer capacity"));
29 }
30 self.copy_host_to_device(data.as_ptr(), self.ptr, std::mem::size_of_val(data))
31 }
32
33 pub fn copy_to_host(&self, data: &mut [f32]) -> Result<()> {
34 if data.len() > self.size {
35 return Err(anyhow!("Host buffer too small"));
36 }
37 self.copy_device_to_host(self.ptr, data.as_mut_ptr(), std::mem::size_of_val(data))
38 }
39
40 #[allow(unused_variables)]
41 fn allocate_gpu_memory(size: usize, device_id: i32) -> Result<*mut u8> {
42 #[cfg(feature = "cuda")]
45 {
46 use cuda_runtime_sys::*;
47 let mut ptr: *mut std::ffi::c_void = std::ptr::null_mut();
48 unsafe {
49 let result = cudaSetDevice(device_id);
50 if result != cudaError_t::cudaSuccess {
51 return Err(anyhow!("Failed to set CUDA device"));
52 }
53
54 let result = cudaMalloc(&mut ptr, size);
55 if result != cudaError_t::cudaSuccess {
56 return Err(anyhow!("Failed to allocate GPU memory"));
57 }
58 }
59 Ok(ptr as *mut u8)
60 }
61
62 #[cfg(not(feature = "cuda"))]
63 {
64 let layout = std::alloc::Layout::from_size_align(size, std::mem::align_of::<f32>())
66 .map_err(|e| anyhow!("Invalid memory layout: {}", e))?;
67 unsafe {
68 let ptr = std::alloc::alloc(layout);
69 if ptr.is_null() {
70 return Err(anyhow!("Failed to allocate memory"));
71 }
72 Ok(ptr)
73 }
74 }
75 }
76
77 fn copy_host_to_device(&self, src: *const f32, dst: *mut f32, size: usize) -> Result<()> {
78 #[cfg(feature = "cuda")]
79 {
80 use cuda_runtime_sys::*;
81 unsafe {
82 let result = cudaMemcpy(
83 dst as *mut std::ffi::c_void,
84 src as *const std::ffi::c_void,
85 size,
86 cudaMemcpyKind::cudaMemcpyHostToDevice,
87 );
88 if result != cudaError_t::cudaSuccess {
89 return Err(anyhow!("Failed to copy data to device"));
90 }
91 }
92 }
93
94 #[cfg(not(feature = "cuda"))]
95 {
96 unsafe {
98 std::ptr::copy_nonoverlapping(src, dst, size / std::mem::size_of::<f32>());
99 }
100 }
101 Ok(())
102 }
103
104 fn copy_device_to_host(&self, src: *const f32, dst: *mut f32, size: usize) -> Result<()> {
105 #[cfg(feature = "cuda")]
106 {
107 use cuda_runtime_sys::*;
108 unsafe {
109 let result = cudaMemcpy(
110 dst as *mut std::ffi::c_void,
111 src as *const std::ffi::c_void,
112 size,
113 cudaMemcpyKind::cudaMemcpyDeviceToHost,
114 );
115 if result != cudaError_t::cudaSuccess {
116 return Err(anyhow!("Failed to copy data from device"));
117 }
118 }
119 }
120
121 #[cfg(not(feature = "cuda"))]
122 {
123 unsafe {
125 std::ptr::copy_nonoverlapping(src, dst, size / std::mem::size_of::<f32>());
126 }
127 }
128 Ok(())
129 }
130
131 pub fn ptr(&self) -> *mut f32 {
132 self.ptr
133 }
134
135 pub fn size(&self) -> usize {
136 self.size
137 }
138
139 pub fn device_id(&self) -> i32 {
140 self.device_id
141 }
142
143 pub fn is_valid(&self) -> bool {
144 !self.ptr.is_null()
145 }
146
147 pub fn zero(&mut self) -> Result<()> {
149 #[cfg(feature = "cuda")]
150 {
151 use cuda_runtime_sys::*;
152 unsafe {
153 let result = cudaMemset(
154 self.ptr as *mut std::ffi::c_void,
155 0,
156 self.size * std::mem::size_of::<f32>(),
157 );
158 if result != cudaError_t::cudaSuccess {
159 return Err(anyhow!("Failed to zero buffer"));
160 }
161 }
162 }
163
164 #[cfg(not(feature = "cuda"))]
165 {
166 unsafe {
167 std::ptr::write_bytes(self.ptr, 0, self.size);
168 }
169 }
170 Ok(())
171 }
172}
173
174impl Drop for GpuBuffer {
175 fn drop(&mut self) {
176 if !self.ptr.is_null() {
177 #[cfg(feature = "cuda")]
178 {
179 use cuda_runtime_sys::*;
180 unsafe {
181 let _ = cudaFree(self.ptr as *mut std::ffi::c_void);
182 }
183 }
184
185 #[cfg(not(feature = "cuda"))]
186 {
187 let layout = std::alloc::Layout::from_size_align(
189 self.size * std::mem::size_of::<f32>(),
190 std::mem::align_of::<f32>(),
191 );
192 if let Ok(layout) = layout {
193 unsafe {
194 std::alloc::dealloc(self.ptr as *mut u8, layout);
195 }
196 }
197 }
198 }
199 }
200}