use std::collections::HashMap;
use std::sync::{Arc, Mutex, OnceLock};
7
/// Size-bucketed pool of GPU (or host-fallback) memory blocks.
///
/// Requested sizes are rounded up to the next power of two so that freed
/// blocks can be recycled by bucket. All bookkeeping lives behind `Mutex`es,
/// so the pool can be shared via `Arc` and used through `&self`.
#[derive(Debug)]
pub struct GpuMemoryPool {
    /// Freed blocks awaiting reuse, keyed by their rounded size in bytes.
    free_blocks: Arc<Mutex<HashMap<usize, Vec<*mut u8>>>>,
    /// Live allocations: pointer -> rounded size in bytes.
    allocated: Arc<Mutex<HashMap<*mut u8, usize>>>,
    /// Bytes currently handed out to callers (pooled free blocks excluded).
    total_allocated: Arc<Mutex<usize>>,
    /// High-water mark of `total_allocated`.
    peak_usage: Arc<Mutex<usize>>,
}

impl GpuMemoryPool {
    /// Creates an empty pool with zero recorded usage.
    pub fn new() -> Self {
        Self {
            free_blocks: Arc::new(Mutex::new(HashMap::new())),
            allocated: Arc::new(Mutex::new(HashMap::new())),
            total_allocated: Arc::new(Mutex::new(0)),
            peak_usage: Arc::new(Mutex::new(0)),
        }
    }

    /// Allocates at least `size` bytes; the block actually handed out is
    /// `size.next_power_of_two()` bytes (so `size == 0` yields a 1-byte block).
    ///
    /// A recycled block from the matching bucket is preferred; otherwise new
    /// memory is obtained (`cudaMalloc` under the `cuda` feature, a leaked
    /// zeroed host buffer otherwise — `clear` reconstitutes and frees those).
    ///
    /// # Errors
    /// Returns `Err` when the underlying CUDA allocation fails.
    pub fn allocate(&self, size: usize) -> Result<*mut u8, String> {
        // Bucket by power of two so one freed allocation can satisfy any
        // later request that rounds to the same bucket.
        let rounded_size = size.next_power_of_two();

        // Try to recycle a block. The free-list guard is a temporary, so it
        // is released at the end of this statement — locks are only ever
        // taken one at a time here. (The original nested free_blocks ->
        // allocated acquisition, opposite to `free`'s order, could deadlock.)
        let recycled = self
            .free_blocks
            .lock()
            .unwrap()
            .get_mut(&rounded_size)
            .and_then(|blocks| blocks.pop());

        if let Some(ptr) = recycled {
            self.allocated.lock().unwrap().insert(ptr, rounded_size);
            // BUG FIX: reused blocks count toward current usage too.
            // Previously this path skipped the accounting, so the matching
            // `free` later underflowed `total_allocated`.
            self.record_usage(rounded_size);
            return Ok(ptr);
        }

        #[cfg(feature = "cuda")]
        let ptr = unsafe {
            let mut ptr: *mut u8 = std::ptr::null_mut();
            use crate::ffi;
            let result = ffi::cudaMalloc(
                &mut ptr as *mut *mut u8 as *mut *mut std::ffi::c_void,
                rounded_size,
            );
            if result != 0 {
                return Err(format!("CUDA malloc failed with error code: {}", result));
            }
            ptr
        };

        // Host fallback: leak a zeroed Vec of exactly `rounded_size` bytes;
        // `clear` rebuilds the Vec with the same capacity to deallocate it.
        #[cfg(not(feature = "cuda"))]
        let ptr = vec![0u8; rounded_size].leak().as_mut_ptr();

        self.allocated.lock().unwrap().insert(ptr, rounded_size);
        self.record_usage(rounded_size);

        Ok(ptr)
    }

    /// Adds `size` bytes to current usage and raises the peak if exceeded.
    fn record_usage(&self, size: usize) {
        // Scope the total-lock so only one lock is held at a time.
        let current = {
            let mut total = self.total_allocated.lock().unwrap();
            *total += size;
            *total
        };
        let mut peak = self.peak_usage.lock().unwrap();
        if current > *peak {
            *peak = current;
        }
    }

    /// Returns `ptr` to the pool's free list for later reuse.
    ///
    /// # Errors
    /// Returns `Err` when `ptr` was not produced by `allocate` or was
    /// already freed (the original double-free message is preserved).
    pub fn free(&self, ptr: *mut u8) -> Result<(), String> {
        // Remove the entry first and drop the guard (end of statement) so
        // the free-list lock below is never nested inside it — same
        // one-lock-at-a-time discipline as `allocate`.
        let size = self.allocated.lock().unwrap().remove(&ptr);

        match size {
            Some(size) => {
                self.free_blocks
                    .lock()
                    .unwrap()
                    .entry(size)
                    .or_default()
                    .push(ptr);

                let mut total = self.total_allocated.lock().unwrap();
                // saturating_sub: defensive against any residual accounting
                // drift — never panic inside the allocator.
                *total = total.saturating_sub(size);
                Ok(())
            }
            None => Err("Attempted to free unallocated pointer".to_string()),
        }
    }

    /// Releases every pooled (free) block back to the underlying allocator.
    /// Blocks still held by callers (tracked in `allocated`) are untouched.
    ///
    /// # Errors
    /// Currently infallible; `Result` is kept for interface stability.
    pub fn clear(&self) -> Result<(), String> {
        let mut free_blocks = self.free_blocks.lock().unwrap();

        for (&size, blocks) in free_blocks.iter() {
            for &ptr in blocks {
                #[cfg(feature = "cuda")]
                unsafe {
                    use crate::ffi;
                    let _ = size; // size is only needed by the host path
                    ffi::cudaFree(ptr as *mut std::ffi::c_void);
                }

                // SAFETY: `ptr` came from `vec![0u8; size].leak()`, so
                // rebuilding with capacity `size` matches the original
                // allocation and lets Drop actually deallocate it.
                // BUG FIX: the previous capacity of 0 deallocated nothing,
                // leaking every pooled block.
                #[cfg(not(feature = "cuda"))]
                unsafe {
                    drop(Vec::from_raw_parts(ptr, 0, size));
                }
            }
        }

        free_blocks.clear();
        Ok(())
    }

    /// Bytes currently handed out to callers.
    pub fn current_usage(&self) -> usize {
        *self.total_allocated.lock().unwrap()
    }

    /// Highest value `current_usage` has ever reached.
    pub fn peak_usage(&self) -> usize {
        *self.peak_usage.lock().unwrap()
    }
}

impl Default for GpuMemoryPool {
    fn default() -> Self {
        Self::new()
    }
}
137
138impl Drop for GpuMemoryPool {
139 fn drop(&mut self) {
140 #[cfg(feature = "cuda")]
141 {
142 let _ = self.clear();
143 }
144 }
145}
146
/// A dense `f32` tensor whose storage is managed by a [`GpuMemoryPool`].
pub struct GpuTensor {
    /// Raw element storage obtained from `pool` — device memory under the
    /// `cuda` feature, leaked host memory otherwise.
    ptr: *mut f32,
    /// Logical dimensions; the element count is the product of the entries.
    shape: Vec<usize>,
    /// Keeps the owning pool alive for the tensor's lifetime so `Drop` can
    /// return `ptr` to it (only exercised in the CUDA build as written).
    #[allow(dead_code)]
    pool: Arc<GpuMemoryPool>,
}
157
158impl GpuTensor {
159 #[cfg(feature = "cuda")]
161 pub fn new(shape: Vec<usize>, pool: Arc<GpuMemoryPool>) -> Result<Self, String> {
162 let numel: usize = shape.iter().product();
163 let size_bytes = numel * std::mem::size_of::<f32>();
164
165 let ptr = pool.allocate(size_bytes)? as *mut f32;
166
167 Ok(Self { ptr, shape, pool })
168 }
169
170 pub fn copy_from_cpu(&mut self, data: &[f32]) -> Result<(), String> {
172 let numel: usize = self.shape.iter().product();
173 if data.len() != numel {
174 return Err("Data size mismatch".to_string());
175 }
176
177 #[cfg(feature = "cuda")]
178 unsafe {
179 use crate::ffi;
180 let result = ffi::cudaMemcpy(
181 self.ptr as *mut std::ffi::c_void,
182 data.as_ptr() as *const std::ffi::c_void,
183 numel * std::mem::size_of::<f32>(),
184 ffi::cudaMemcpyKind::cudaMemcpyHostToDevice,
185 );
186 if result != 0 {
187 return Err(format!("CUDA memcpy H2D failed with error code: {}", result));
188 }
189 }
190
191 #[cfg(not(feature = "cuda"))]
192 unsafe {
193 std::ptr::copy_nonoverlapping(data.as_ptr(), self.ptr, numel);
195 }
196
197 Ok(())
198 }
199
200 pub fn copy_to_cpu(&self) -> Result<Vec<f32>, String> {
202 let numel: usize = self.shape.iter().product();
203 let mut data = vec![0.0f32; numel];
204
205 #[cfg(feature = "cuda")]
206 unsafe {
207 use crate::ffi;
208 let result = ffi::cudaMemcpy(
209 data.as_mut_ptr() as *mut std::ffi::c_void,
210 self.ptr as *const std::ffi::c_void,
211 numel * std::mem::size_of::<f32>(),
212 ffi::cudaMemcpyKind::cudaMemcpyDeviceToHost,
213 );
214 if result != 0 {
215 return Err(format!("CUDA memcpy D2H failed with error code: {}", result));
216 }
217 }
218
219 #[cfg(not(feature = "cuda"))]
220 unsafe {
221 std::ptr::copy_nonoverlapping(self.ptr, data.as_mut_ptr(), numel);
223 }
224
225 Ok(data)
226 }
227
228 pub fn as_ptr(&self) -> *const f32 {
230 self.ptr
231 }
232
233 pub fn as_mut_ptr(&mut self) -> *mut f32 {
235 self.ptr
236 }
237
238 pub fn shape(&self) -> &[usize] {
240 &self.shape
241 }
242}
243
244impl Drop for GpuTensor {
245 fn drop(&mut self) {
246 #[cfg(feature = "cuda")]
247 {
248 let _ = self.pool.free(self.ptr as *mut u8);
249 }
250 }
251}
252
253static mut GLOBAL_GPU_POOL: Option<Arc<GpuMemoryPool>> = None;
255
256#[allow(static_mut_refs)]
258pub fn get_global_gpu_pool() -> Arc<GpuMemoryPool> {
259 unsafe {
260 if GLOBAL_GPU_POOL.is_none() {
261 GLOBAL_GPU_POOL = Some(Arc::new(GpuMemoryPool::new()));
262 }
263 GLOBAL_GPU_POOL.as_ref().unwrap().clone()
264 }
265}
266
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created pool reports zero usage and zero peak.
    #[test]
    fn test_memory_pool() {
        let pool = GpuMemoryPool::new();
        assert_eq!(pool.current_usage(), 0);
        assert_eq!(pool.peak_usage(), 0);
    }

    /// One allocate/free cycle: size rounds up to a power of two, usage
    /// returns to zero, peak records the high-water mark.
    /// Host-only: avoids requiring a CUDA device on the test machine.
    #[cfg(not(feature = "cuda"))]
    #[test]
    fn test_allocate_and_free() {
        let pool = GpuMemoryPool::new();
        let ptr = pool.allocate(100).expect("allocation should succeed");
        assert_eq!(pool.current_usage(), 128);
        assert_eq!(pool.peak_usage(), 128);
        pool.free(ptr).expect("freeing a live pointer should succeed");
        assert_eq!(pool.current_usage(), 0);
        assert_eq!(pool.peak_usage(), 128);
    }

    /// Freeing a pointer the pool never handed out must be an error.
    #[test]
    fn test_free_unknown_pointer() {
        let pool = GpuMemoryPool::new();
        assert!(pool.free(std::ptr::null_mut()).is_err());
    }
}